From 404d6b69e3815c198477537a43d8883c79cfb570 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 4 Feb 2025 13:09:29 -0800 Subject: [PATCH 01/32] repro --- tests/cpp/test_rope.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_rope.cpp b/tests/cpp/test_rope.cpp index 74473654eca..9fab54d8d83 100644 --- a/tests/cpp/test_rope.cpp +++ b/tests/cpp/test_rope.cpp @@ -49,7 +49,7 @@ struct RopeConfig { class RopeTest : public NVFuserFixtureParamTest { protected: void SetUp() override { - EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); + // EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); NVFuserTest::SetUp(); } }; From 1b8fc0fc0825f9540904d1b096367afd23b47ebc Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 4 Feb 2025 20:08:16 -0800 Subject: [PATCH 02/32] renable resize sched --- tests/cpp/test_rope.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_rope.cpp b/tests/cpp/test_rope.cpp index 9fab54d8d83..74473654eca 100644 --- a/tests/cpp/test_rope.cpp +++ b/tests/cpp/test_rope.cpp @@ -49,7 +49,7 @@ struct RopeConfig { class RopeTest : public NVFuserFixtureParamTest { protected: void SetUp() override { - // EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); + EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); NVFuserTest::SetUp(); } }; From 861c187d65f016a274f440253fba4c575944ec6e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 4 Feb 2025 20:35:07 -0800 Subject: [PATCH 03/32] cleanup --- csrc/scheduler/tools/loop_domain_scheduler.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index 16734cfb294..b54cbb0b3e6 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -99,6 +99,8 @@ class LoopDomainSchedulerReplayTransform 
: OptInConstDispatch { const std::vector& output_ids_; }; +// Replay a given IterDomain transform expression on the loop domain +// of a given tensor using specified loop IDs as its inputs. class ReplayForwardTransformOnLoopDomain : OptInConstDispatch { public: static void replayAs( @@ -541,7 +543,11 @@ void scheduleLoopDomainsBy( // When the direction is forward, the TensorView transform // APIs, e.g., TensorView::split, can be used, which doesn't need - // to use TensorView::setLoopDomain. + // to use TensorView::setLoopDomain. This is important as + // setLoopDomain may result in losing extra IDs added by prior + // scheduleLoopDomain calls, which was indeed the case with the + // Llama 3 RoPE backward (see also + // https://github.com/NVIDIA/Fuser/issues/3571). if (replay_dir_tv == Direction::Forward) { ReplayForwardTransformOnLoopDomain::replayAs(tv, input_ids, transform); continue; From 12d0e852d015e8ce34f343298201716765d668f3 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 6 Feb 2025 19:01:56 -0800 Subject: [PATCH 04/32] debug print --- tests/cpp/test_rope.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/cpp/test_rope.cpp b/tests/cpp/test_rope.cpp index 74473654eca..aa71d874ae3 100644 --- a/tests/cpp/test_rope.cpp +++ b/tests/cpp/test_rope.cpp @@ -1204,8 +1204,6 @@ TEST_P(Phi3RopeTest, Bwd) { auto T189 = castOp(DataType::BFloat16, T188); fusion.addOutput(T189); - fusion.print(); - auto options_bf16 = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); From 643779d08748448798f08c080af9af7e12401617 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 6 Feb 2025 19:02:37 -0800 Subject: [PATCH 05/32] enable resize scheduler --- csrc/scheduler/resize.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index 1c198aee8dc..e066d5a2420 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -62,11 +62,13 @@ std::pair getLargestTensor( } // namespace bool 
ResizeScheduler::canScheduleCompileTime(Fusion* fusion) { +#if 0 if (!isOptionEnabled(EnableOption::ResizeScheduler)) { scheduler_debug_utils::canScheduleRejectReason( schedulerType(), "Not enabled"); return false; } +#endif if (!scheduler_tools::hasResizeBasedOps(fusion)) { scheduler_debug_utils::canScheduleRejectReason( From 4f28c441f2ab9c879727ec7294470d6902e93ad1 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 7 Feb 2025 09:06:57 -0800 Subject: [PATCH 06/32] fix --- tests/cpp/test_move_pad.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/cpp/test_move_pad.cpp b/tests/cpp/test_move_pad.cpp index aa57dfec770..deebb7fd2e6 100644 --- a/tests/cpp/test_move_pad.cpp +++ b/tests/cpp/test_move_pad.cpp @@ -116,8 +116,7 @@ TEST_F(MovePadTest, BinaryBroadcastOnNonCatDim) { EXPECT_THAT( runtime->fusionSegments()->groups(), UnorderedElementsAre( - HeuristicIs(SchedulerType::NoOp), - HeuristicIs(SchedulerType::PointWise))); + HeuristicIs(SchedulerType::Resize))); testValidate( executor_cache.fusion(), out_tensors, aten_inputs, __LINE__, __FILE__); From 5f851ea35b0082d4efea361e5c9459c818666eef Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 7 Feb 2025 21:23:38 -0800 Subject: [PATCH 07/32] fix --- csrc/scheduler/resize.cpp | 7 +++++-- csrc/scheduler/tools/loop_domain_scheduler.cpp | 13 +++++++++++++ csrc/scheduler/tools/loop_domain_scheduler.h | 3 +++ tests/cpp/test_resize.cpp | 4 ++-- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index e066d5a2420..4b1d892e4e6 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -239,7 +239,7 @@ std::unique_ptr ResizeScheduler::computeHeuristics( // Before applying the vectorization split, any reshape transform of // the largest input will be cancelled whenever possible, so the // largest input is used as the reference of vectorization. - auto vec_ref_tv = largest_input != nullptr ? 
largest_input : ref_tv; + auto vec_ref_tv = ref_tv; // Only consider the innermost dimension to vectorize for now. // TODO: Consider vectorizing merged IDs, not just the innermost @@ -478,9 +478,12 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { } if (vec_factor > 1) { - auto vec_ref_tv = largest_input != nullptr ? largest_input : ref_tv; + // auto vec_ref_tv = largest_input != nullptr ? largest_input : + // ref_tv; + auto vec_ref_tv = ref_tv; const auto tvs_to_vectorize = scheduler_utils::getInputsOutputsWithInnerDim(vec_ref_tv, true, true); + std::cerr << "TVs to vec: " << toDelimitedString(tvs_to_vectorize) << "\n"; for (auto tv_to_vectorize : tvs_to_vectorize) { if (tv_to_vectorize->isFusionInput()) { for (auto consumer_tv : ir_utils::consumerTvsOf(tv_to_vectorize)) { diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index b54cbb0b3e6..2b8fddde023 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -683,6 +683,15 @@ void cancelReshapeInLoopDomains(TensorView* from_tv) { {reshape_out->getLogicalDomain().begin(), reshape_out->getLogicalDomain().end()}); + auto reshape_exprs_with_innermost_logical_id = + DependencyCheck::getAllExprsBetween( + {reshape_out->getRootDomain().begin(), + reshape_out->getRootDomain().end()}, + {reshape_out->getLogicalDomain().back()}); + std::unordered_set reshape_exprs_with_innermost_logical_id_set = { + reshape_exprs_with_innermost_logical_id.begin(), + reshape_exprs_with_innermost_logical_id.end()}; + auto reshape_out_loop_domain = reshape_out->getLoopDomain(); for (auto reshape_exprs_it = reshape_exprs.rbegin(); @@ -690,6 +699,10 @@ void cancelReshapeInLoopDomains(TensorView* from_tv) { ++reshape_exprs_it) { auto reshape_expr = *reshape_exprs_it; + if (reshape_exprs_with_innermost_logical_id_set.count(reshape_expr)) { + continue; + } + // If any of the output IDs of 
reshape_expr is not found in // cancellable_ids, that means the expr cannot be cancelled. if (std::any_of( diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index fa0d4e0d2ae..d8fc635bc95 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -109,5 +109,8 @@ void scheduleLoopDomainsBy( // but that is not currently supported. void cancelReshapeInLoopDomains(TensorView* from_tv); +std::optional getInnermostCancelableReshapePosition( + TensorView* from_tv); + } // namespace scheduler_tools } // namespace nvfuser diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index bbf26c5e49f..eb6a95c3a5a 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -4873,7 +4873,7 @@ TEST_P(ResizeSchedulerTest, SliceRotateCat) { Fusion& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); - std::vector shape({-1, 100}); + std::vector shape({-1, 128}); EnableOptionsGuard enable_options_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); @@ -4899,7 +4899,7 @@ TEST_P(ResizeSchedulerTest, SliceRotateCat) { auto tv5 = cat({tv4, tv2}, 1); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 100}, options); + auto t0 = at::randn({16, 128}, options); std::vector inputs({t0}); fusion.addOutput(tv5); From d7c9af139889b03395761c91cbc5dfe7420e2bce Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 7 Feb 2025 21:38:30 -0800 Subject: [PATCH 08/32] cleanup --- csrc/options.cpp | 2 +- csrc/options.h | 2 +- csrc/preseg_passes/pre_segmenter.cpp | 2 +- csrc/scheduler/resize.cpp | 7 ++----- tests/cpp/test_move_pad.cpp | 11 +++++++++-- tests/cpp/test_resize.cpp | 20 ++++---------------- tests/cpp/test_rope.cpp | 8 +------- 7 files changed, 19 insertions(+), 33 deletions(-) diff --git a/csrc/options.cpp b/csrc/options.cpp index de1ee6cbcd8..01049cd3d9b 100644 --- 
a/csrc/options.cpp +++ b/csrc/options.cpp @@ -163,7 +163,6 @@ const std::unordered_map& getEnableOptions() { {"kernel_profile", EnableOption::KernelProfile}, {"memory_promotion", EnableOption::MemoryPromotion}, {"reuse_zeroed_memory", EnableOption::ReuseZeroedMemory}, - {"resize_scheduler", EnableOption::ResizeScheduler}, {"static_fusion_count", EnableOption::StaticFusionCount}, {"wait_debugger", EnableOption::WaitDebugger}, {"warn_register_spill", EnableOption::WarnRegisterSpill}, @@ -210,6 +209,7 @@ const std::unordered_map& getDisableOptions() { {"kernel_reuse", DisableOption::KernelReuse}, {"var_name_remapping", DisableOption::VarNameRemapping}, {"welford_vectorization", DisableOption::WelfordVectorization}, + {"resize_scheduler", DisableOption::ResizeScheduler}, {"reuse_mismatched_type_registers", DisableOption::ReuseMismatchedTypeRegisters}, {"multidevice", DisableOption::Multidevice}}; diff --git a/csrc/options.h b/csrc/options.h index 95ee5fed511..23143f6de74 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -104,7 +104,6 @@ enum class EnableOption { KernelProfile, //! Enable intra-kernel performance profiling MemoryPromotion, //! Enable promotion of memory types for non-pointwise ops ReuseZeroedMemory, //! Re-use zeroed memory used for grid synchronization - ResizeScheduler, //! Enable the resize scheduler StaticFusionCount, //! Enable using single static count in kernel name WaitDebugger, // Used for debugging multi-GPU. The rank given in the argument // will wait for `gdb attach` at the start. @@ -147,6 +146,7 @@ enum class DisableOption { //! need this in particular to investigate possible conflicts //! between nvFuser communicator and the framework also setting //! up `c10d::ProcessGroup` + ResizeScheduler, //! Disable the resize scheduler EndOfOption //! 
Placeholder for counting the number of elements }; diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp index 8017e116f6e..042f03191f7 100644 --- a/csrc/preseg_passes/pre_segmenter.cpp +++ b/csrc/preseg_passes/pre_segmenter.cpp @@ -69,7 +69,7 @@ namespace nvfuser::preseg_passes { // currently only limited to pointwise patterns and does not // support, for example, reductions, etc, so this preseg pass still // may be preferable in some cases. - if (!isOptionEnabled(EnableOption::ResizeScheduler)) { + if (isOptionDisabled(DisableOption::ResizeScheduler)) { OptimizationPass::runPass(fusion); } // NOTE vvv this doesn't really work, since our type promotion to higher diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index 4b1d892e4e6..dfff12c51c2 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -62,13 +62,10 @@ std::pair getLargestTensor( } // namespace bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) { -#if 0 - if (!isOptionEnabled(EnableOption::ResizeScheduler)) { - scheduler_debug_utils::canScheduleRejectReason( - schedulerType(), "Not enabled"); + if (isOptionDisabled(DisableOption::ResizeScheduler)) { + scheduler_debug_utils::canScheduleRejectReason(schedulerType(), "Disabled"); return false; } -#endif if (!scheduler_tools::hasResizeBasedOps(fusion)) { scheduler_debug_utils::canScheduleRejectReason( diff --git a/tests/cpp/test_move_pad.cpp b/tests/cpp/test_move_pad.cpp index deebb7fd2e6..8987b0fc7a4 100644 --- a/tests/cpp/test_move_pad.cpp +++ b/tests/cpp/test_move_pad.cpp @@ -21,7 +21,13 @@ using testing::IsTrue; using testing::Property; using testing::UnorderedElementsAre; -using MovePadTest = NVFuserTest; +class MovePadTest : public NVFuserTest { + protected: + void SetUp() override { + DisableOptionsGuard::getCurOptions().set(DisableOption::ResizeScheduler); + NVFuserTest::SetUp(); + } +}; TEST_F(MovePadTest, UnaryCat) { auto fusion = std::make_unique(); @@ -116,7 +122,8 
@@ TEST_F(MovePadTest, BinaryBroadcastOnNonCatDim) { EXPECT_THAT( runtime->fusionSegments()->groups(), UnorderedElementsAre( - HeuristicIs(SchedulerType::Resize))); + HeuristicIs(SchedulerType::NoOp), + HeuristicIs(SchedulerType::PointWise))); testValidate( executor_cache.fusion(), out_tensors, aten_inputs, __LINE__, __FILE__); diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index eb6a95c3a5a..3fbc3b9011b 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -55,21 +55,9 @@ void checkLoopDomainEquivalence( } // namespace -class ResizeTest : public NVFuserTest { - protected: - void SetUp() override { - EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); - NVFuserTest::SetUp(); - } -}; +using ResizeTest = NVFuserTest; -class ResizeSchedulerTest : public NVFuserFixtureParamTest { - protected: - void SetUp() override { - EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); - NVFuserFixtureParamTest::SetUp(); - } -}; +using ResizeSchedulerTest = NVFuserFixtureParamTest; using testing::Each; using testing::HasSubstr; @@ -5763,7 +5751,7 @@ TEST_F(ResizeTest, TraversalForInliningPosition) { // Disable the resize schedule because the original issue happened // with the pointwise scheduler - EnableOptionsGuard::getCurOptions().unset(EnableOption::ResizeScheduler); + DisableOptionsGuard::getCurOptions().unset(DisableOption::ResizeScheduler); auto tv0 = makeContigConcreteTensor({16}); fusion.addInput(tv0); @@ -5865,7 +5853,7 @@ TEST_F(ResizeTest, Repro3801) { // Disable the resize schedule because the original issue happened // with the pointwise scheduler - EnableOptionsGuard::getCurOptions().unset(EnableOption::ResizeScheduler); + DisableOptionsGuard::getCurOptions().unset(DisableOption::ResizeScheduler); auto T13 = makeContigConcreteTensor({1, 16}); fusion.addInput(T13); diff --git a/tests/cpp/test_rope.cpp b/tests/cpp/test_rope.cpp index aa71d874ae3..2a1c7ac0fb7 100644 --- 
a/tests/cpp/test_rope.cpp +++ b/tests/cpp/test_rope.cpp @@ -46,13 +46,7 @@ struct RopeConfig { } }; -class RopeTest : public NVFuserFixtureParamTest { - protected: - void SetUp() override { - EnableOptionsGuard::getCurOptions().set(EnableOption::ResizeScheduler); - NVFuserTest::SetUp(); - } -}; +using RopeTest = NVFuserFixtureParamTest; using MistralRopeTest = RopeTest; From 4cb7fbd7833b2707e9cb8953d3224718ba04dcf0 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 10 Feb 2025 14:04:55 -0800 Subject: [PATCH 09/32] fix --- csrc/scheduler/resize.cpp | 3 ++- .../scheduler/tools/loop_domain_scheduler.cpp | 24 +++++++++++-------- csrc/scheduler/tools/loop_domain_scheduler.h | 4 +++- tests/cpp/test_resize.cpp | 4 ++-- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index dfff12c51c2..f9e60a34d93 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -299,7 +299,8 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { // The tensors are going to be reordered to align with the largest // input. To make it work, merge operations for reshape should be // cancelled. 
- scheduler_tools::cancelReshapeInLoopDomains(largest_input); + scheduler_tools::cancelReshapeInLoopDomains( + largest_input, /*skip_innermost_id=*/true); } for (auto expr : fusion->exprs()) { diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index 2b8fddde023..ccab7fce0e1 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -588,7 +588,7 @@ void scheduleLoopDomainsBy( return; } -void cancelReshapeInLoopDomains(TensorView* from_tv) { +void cancelReshapeInLoopDomains(TensorView* from_tv, bool skip_innermost_id) { Fusion* fusion = from_tv->fusion(); IdModel id_model(fusion, /*build_graphs=*/false); id_model.buildExactGraph(); @@ -683,14 +683,17 @@ void cancelReshapeInLoopDomains(TensorView* from_tv) { {reshape_out->getLogicalDomain().begin(), reshape_out->getLogicalDomain().end()}); - auto reshape_exprs_with_innermost_logical_id = - DependencyCheck::getAllExprsBetween( - {reshape_out->getRootDomain().begin(), - reshape_out->getRootDomain().end()}, - {reshape_out->getLogicalDomain().back()}); - std::unordered_set reshape_exprs_with_innermost_logical_id_set = { - reshape_exprs_with_innermost_logical_id.begin(), - reshape_exprs_with_innermost_logical_id.end()}; + std::unordered_set reshape_exprs_with_innermost_logical_id_set; + if (skip_innermost_id) { + auto reshape_exprs_with_innermost_logical_id = + DependencyCheck::getAllExprsBetween( + {reshape_out->getRootDomain().begin(), + reshape_out->getRootDomain().end()}, + {reshape_out->getLogicalDomain().back()}); + reshape_exprs_with_innermost_logical_id_set = { + reshape_exprs_with_innermost_logical_id.begin(), + reshape_exprs_with_innermost_logical_id.end()}; + } auto reshape_out_loop_domain = reshape_out->getLoopDomain(); @@ -699,7 +702,8 @@ void cancelReshapeInLoopDomains(TensorView* from_tv) { ++reshape_exprs_it) { auto reshape_expr = *reshape_exprs_it; - if 
(reshape_exprs_with_innermost_logical_id_set.count(reshape_expr)) { + if (skip_innermost_id && + reshape_exprs_with_innermost_logical_id_set.count(reshape_expr)) { continue; } diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index d8fc635bc95..9a2e8a14764 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -107,7 +107,9 @@ void scheduleLoopDomainsBy( // iter domain is reduced, the split needs to remain. If a reshape // only consists of merge transforms, cancellation should be possible, // but that is not currently supported. -void cancelReshapeInLoopDomains(TensorView* from_tv); +void cancelReshapeInLoopDomains( + TensorView* from_tv, + bool skip_innermost_id = false); std::optional getInnermostCancelableReshapePosition( TensorView* from_tv); diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 3fbc3b9011b..e44ac915610 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -5751,7 +5751,7 @@ TEST_F(ResizeTest, TraversalForInliningPosition) { // Disable the resize schedule because the original issue happened // with the pointwise scheduler - DisableOptionsGuard::getCurOptions().unset(DisableOption::ResizeScheduler); + DisableOptionsGuard::getCurOptions().set(DisableOption::ResizeScheduler); auto tv0 = makeContigConcreteTensor({16}); fusion.addInput(tv0); @@ -5853,7 +5853,7 @@ TEST_F(ResizeTest, Repro3801) { // Disable the resize schedule because the original issue happened // with the pointwise scheduler - DisableOptionsGuard::getCurOptions().unset(DisableOption::ResizeScheduler); + DisableOptionsGuard::getCurOptions().set(DisableOption::ResizeScheduler); auto T13 = makeContigConcreteTensor({1, 16}); fusion.addInput(T13); From e1b17909e843b6b10b5683d4c441f75a42603445 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 12 Feb 2025 21:52:34 -0800 Subject: [PATCH 10/32] fix --- 
csrc/scheduler/tools/loop_domain_scheduler.cpp | 2 -- csrc/scheduler/tools/resize_utils.cpp | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index ccab7fce0e1..c79ad79245e 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -528,12 +528,10 @@ void scheduleLoopDomainsBy( Direction replay_dir_tv = Direction::Undefined; if (replay_dir != Direction::Backward && input_ids.size() == transform->inputs().size()) { - NVF_ERROR(output_ids.empty()); replay_dir_tv = Direction::Forward; } else if ( replay_dir != Direction::Forward && output_ids.size() == transform->outputs().size()) { - NVF_ERROR(input_ids.empty()); replay_dir_tv = Direction::Backward; } else { // Replay not possible since none of inputs nor outputs are connected with diff --git a/csrc/scheduler/tools/resize_utils.cpp b/csrc/scheduler/tools/resize_utils.cpp index f281b40fe79..11b7c660be8 100644 --- a/csrc/scheduler/tools/resize_utils.cpp +++ b/csrc/scheduler/tools/resize_utils.cpp @@ -72,7 +72,8 @@ void propagateResizeToInputs(Expr* resize_tensor_op) { continue; } - scheduler_tools::scheduleLoopDomainsBy(tvs_to_schedule, resize); + scheduler_tools::scheduleLoopDomainsBy( + tvs_to_schedule, resize, Direction::Forward); } } From c79e234fe151f870b8b8b76f326e5a6b54d5a82a Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 12 Feb 2025 23:11:54 -0800 Subject: [PATCH 11/32] cleanup --- csrc/scheduler/resize.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index f9e60a34d93..b8768c44c34 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -481,7 +481,6 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { auto vec_ref_tv = ref_tv; const auto tvs_to_vectorize = scheduler_utils::getInputsOutputsWithInnerDim(vec_ref_tv, true, true); 
- std::cerr << "TVs to vec: " << toDelimitedString(tvs_to_vectorize) << "\n"; for (auto tv_to_vectorize : tvs_to_vectorize) { if (tv_to_vectorize->isFusionInput()) { for (auto consumer_tv : ir_utils::consumerTvsOf(tv_to_vectorize)) { From c5cc1ba248a2720cfb80046d3ef7a649691e8446 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 13 Feb 2025 16:57:14 -0800 Subject: [PATCH 12/32] WIP: extent simplification --- csrc/ops/alias.cpp | 78 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 67 insertions(+), 11 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 5729fed5b3f..1544cf75dd4 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -747,17 +747,59 @@ TensorView* slice( ", Expected: ", ndims); - const auto normalize_slice_range = [&manual_normalization]( + const auto get_int = [](Val* x) -> std::optional { + if (x->isConstInt()) { + return x->evaluate().as(); + } else { + return std::nullopt; + } + }; + + // Simplify patterns like min(min(x, 32), 32) to min(x, 32) as it + // isn't uncommon + const auto min_expr = [&](Val* x, Val* y) -> Val* { + auto y_int = get_int(y); + auto bop = dynamic_cast(x->definition()); + if (y_int != std::nullopt && bop != nullptr && + bop->getBinaryOpType() == BinaryOpType::Min) { + if (auto lhs_int = get_int(bop->lhs()); lhs_int != std::nullopt) { + return SimplifyingIrBuilder::minExpr( + bop->rhs(), IrBuilder::create(std::min(*lhs_int, *y_int))); + } else if (auto rhs_int = get_int(bop->rhs()); rhs_int != std::nullopt) { + return SimplifyingIrBuilder::minExpr( + bop->lhs(), IrBuilder::create(std::min(*rhs_int, *y_int))); + } + } + + return SimplifyingIrBuilder::minExpr(x, y); + }; + + const auto normalize_slice_range = [&manual_normalization, &min_expr]( Slice range, Val* extent) -> Slice { + std::optional extent_int; + if (extent->isConstInt()) { + extent_int = extent->evaluate().as(); + } + auto cast_extent = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); auto zero = 
FusionGuard::getCurFusion()->zeroVal(DataType::Index); + std::optional start_int; + if (range.start->isConstInt()) { + start_int = range.start->evaluate().as(); + } + std::optional stop_int; + if (range.stop->isConstInt()) { + stop_int = range.start->evaluate().as(); + } + // norm_start = max(0, start < 0 ? start + extent : start) if (range.start == nullptr) { range.start = zero; - } else if (!range.start->isZeroInt()) { + start_int = 0; + } else if (start_int != 0) { range.start = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); if (!manual_normalization) { @@ -768,23 +810,37 @@ TensorView* slice( SimplifyingIrBuilder::addExpr(range.start, cast_extent), range.start)); } + if (range.start->isConstInt()) { + start_int = range.start->evaluate().as(); + } } + if (range.stop) { + std::cerr << "Normalizing stop from: " << range.stop->toString() + << ", cast extent: " << cast_extent->toInlineString() << "\n"; + } // norm_stop = max(norm_start, min(extent, stop < 0 ? stop + extent : stop) if (range.stop == nullptr) { range.stop = cast_extent; } else if (!range.stop->sameAs(extent)) { range.stop = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); - if (!manual_normalization) { - range.stop = SimplifyingIrBuilder::maxExpr( - range.start, - SimplifyingIrBuilder::minExpr( - cast_extent, - SimplifyingIrBuilder::whereExpr( - SimplifyingIrBuilder::ltExpr(range.stop, zero), - SimplifyingIrBuilder::addExpr(range.stop, cast_extent), - range.stop))); + // Commonly, range.start is zero and stop is non negative + if (start_int == 0 && stop_int >= 0) { + range.stop = min_expr(cast_extent, range.stop); + } else { + if (!manual_normalization) { + range.stop = SimplifyingIrBuilder::maxExpr( + range.start, + min_expr( + cast_extent, + SimplifyingIrBuilder::whereExpr( + SimplifyingIrBuilder::ltExpr(range.stop, zero), + SimplifyingIrBuilder::addExpr(range.stop, cast_extent), + range.stop))); + std::cerr << "Normalized to : " << range.stop->toInlineString() + 
<< "\n"; + } } } From cc4f1c2f81e5c1343bf8fd541219e0ce21720ea3 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 13 Feb 2025 17:21:43 -0800 Subject: [PATCH 13/32] further simplification --- csrc/ir/builder.cpp | 11 +++++++++++ csrc/ops/alias.cpp | 6 ++++++ 2 files changed, 17 insertions(+) diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index ef8b95a70c8..5121a6104ce 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -380,6 +380,17 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) { } else if (rhs->isConst()) { return addExpr(lhs, rhs->value(), rhs->dtype()); } else { + // Simplify x + (-x) to 0 + if (auto neg_expr = dynamic_cast(lhs->definition()); + neg_expr != nullptr && neg_expr->getUnaryOpType() == UnaryOpType::Neg && + neg_expr->in()->sameAs(rhs)) { + return lhs->fusion()->zeroVal(lhs->dtype()); + } else if (auto neg_expr = dynamic_cast(rhs->definition()); + neg_expr != nullptr && + neg_expr->getUnaryOpType() == UnaryOpType::Neg && + neg_expr->in()->sameAs(lhs)) { + return lhs->fusion()->zeroVal(lhs->dtype()); + } return IrBuilder::addExpr(lhs, rhs); } } diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 1544cf75dd4..46192b4fc2e 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -758,7 +758,13 @@ TensorView* slice( // Simplify patterns like min(min(x, 32), 32) to min(x, 32) as it // isn't uncommon const auto min_expr = [&](Val* x, Val* y) -> Val* { + auto x_int = get_int(x); auto y_int = get_int(y); + if (x_int == 0) { + return x; + } else if (y_int == 0) { + return y; + } auto bop = dynamic_cast(x->definition()); if (y_int != std::nullopt && bop != nullptr && bop->getBinaryOpType() == BinaryOpType::Min) { From 4994d8487a0e84615db9653f20c4d063397653ee Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 13 Feb 2025 18:05:44 -0800 Subject: [PATCH 14/32] Do some more simplifications specific to extents --- csrc/ir/builder.cpp | 11 +++++++ csrc/ops/alias.cpp | 73 +++++++++++++++++++++++++++++++++++++-------- 
2 files changed, 72 insertions(+), 12 deletions(-) diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index ef8b95a70c8..5121a6104ce 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -380,6 +380,17 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) { } else if (rhs->isConst()) { return addExpr(lhs, rhs->value(), rhs->dtype()); } else { + // Simplify x + (-x) to 0 + if (auto neg_expr = dynamic_cast(lhs->definition()); + neg_expr != nullptr && neg_expr->getUnaryOpType() == UnaryOpType::Neg && + neg_expr->in()->sameAs(rhs)) { + return lhs->fusion()->zeroVal(lhs->dtype()); + } else if (auto neg_expr = dynamic_cast(rhs->definition()); + neg_expr != nullptr && + neg_expr->getUnaryOpType() == UnaryOpType::Neg && + neg_expr->in()->sameAs(lhs)) { + return lhs->fusion()->zeroVal(lhs->dtype()); + } return IrBuilder::addExpr(lhs, rhs); } } diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 5729fed5b3f..fd543e37df4 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -747,17 +747,58 @@ TensorView* slice( ", Expected: ", ndims); - const auto normalize_slice_range = [&manual_normalization]( - Slice range, Val* extent) -> Slice { + const auto get_int = [](Val* x) -> std::optional { + if (x != nullptr && x->isConstInt()) { + return x->evaluate().as(); + } else { + return std::nullopt; + } + }; + + // Specialized min for extents. Do some more simplification beyond + // SimplifyingIrBuilder that are only valid for extents. + const auto min_extents = [&](Val* x, Val* y) -> Val* { + auto x_int = get_int(x); + auto y_int = get_int(y); + // Since extents are never negative, if one is 0, that must be the mininum. + if (x_int == 0) { + return x; + } else if (y_int == 0) { + return y; + } + // Simplify patterns like min(min(x, 32), 32) to min(x, 32) as it + // isn't uncommon. 
+ auto bop = dynamic_cast(x->definition()); + if (y_int != std::nullopt && bop != nullptr && + bop->getBinaryOpType() == BinaryOpType::Min) { + if (auto lhs_int = get_int(bop->lhs()); lhs_int != std::nullopt) { + return SimplifyingIrBuilder::minExpr( + bop->rhs(), IrBuilder::create(std::min(*lhs_int, *y_int))); + } else if (auto rhs_int = get_int(bop->rhs()); rhs_int != std::nullopt) { + return SimplifyingIrBuilder::minExpr( + bop->lhs(), IrBuilder::create(std::min(*rhs_int, *y_int))); + } + } + + return SimplifyingIrBuilder::minExpr(x, y); + }; + + const auto normalize_slice_range = + [&manual_normalization, &min_extents, &get_int]( + Slice range, Val* extent) -> Slice { auto cast_extent = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index); + auto start_int = get_int(range.start); + auto stop_int = get_int(range.stop); + // norm_start = max(0, start < 0 ? start + extent : start) if (range.start == nullptr) { range.start = zero; - } else if (!range.start->isZeroInt()) { + start_int = 0; + } else if (start_int != 0) { range.start = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); if (!manual_normalization) { @@ -768,6 +809,9 @@ TensorView* slice( SimplifyingIrBuilder::addExpr(range.start, cast_extent), range.start)); } + if (range.start->isConstInt()) { + start_int = range.start->evaluate().as(); + } } // norm_stop = max(norm_start, min(extent, stop < 0 ? 
stop + extent : stop) @@ -776,15 +820,20 @@ TensorView* slice( } else if (!range.stop->sameAs(extent)) { range.stop = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); - if (!manual_normalization) { - range.stop = SimplifyingIrBuilder::maxExpr( - range.start, - SimplifyingIrBuilder::minExpr( - cast_extent, - SimplifyingIrBuilder::whereExpr( - SimplifyingIrBuilder::ltExpr(range.stop, zero), - SimplifyingIrBuilder::addExpr(range.stop, cast_extent), - range.stop))); + // Commonly, range.start is zero and stop is non negative + if (start_int == 0 && stop_int >= 0) { + range.stop = min_extents(cast_extent, range.stop); + } else { + if (!manual_normalization) { + range.stop = SimplifyingIrBuilder::maxExpr( + range.start, + min_extents( + cast_extent, + SimplifyingIrBuilder::whereExpr( + SimplifyingIrBuilder::ltExpr(range.stop, zero), + SimplifyingIrBuilder::addExpr(range.stop, cast_extent), + range.stop))); + } } } From cd59f29943b951470592623085bc16b7681c3aa1 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 13 Feb 2025 19:16:49 -0800 Subject: [PATCH 15/32] test fix --- tests/cpp/test_gpu3.cpp | 22 ++++++++++++++++++++++ tests/cpp/test_resize.cpp | 4 ++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 69b11a5d69f..43d3bc4ed19 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -9342,6 +9342,28 @@ TEST_F(NVFuserTest, RegisteredExactMappingWithExtentReplacment) { } } +TEST_F(NVFuserTest, TMP) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + fusion.addOutput(tv1); + + tv1->split(1, 4); + auto rf_tv = tv1->rFactor({-1}); + std::cerr << "RF: " << rf_tv->toString() << "\n"; + + fusion.print(); + + ComputeAtMap ca_map(&fusion); + scheduler_utils::propagateReshapeTransforms(&fusion, ca_map); + + fusion.print(); +} + // Test file size should be up to 10K LoC. 
Create a new file for more tests. } // namespace nvfuser diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index a4988592269..d70bb7c17e8 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -1397,7 +1397,7 @@ TEST_F(ResizeTest, SliceExtentSimplification) { // By default, the extent of the tv1 domain is: // i0 + ( ( fmax(0, ( fmin(i0, 1) )) ) + ( -i0 ) ) // This should be simplified to just: - // fmax(0, ( fmin(i0, 1) )) + // fmin(i0, 1) fusion.addOutput(tv1); @@ -1405,7 +1405,7 @@ TEST_F(ResizeTest, SliceExtentSimplification) { auto bop = dynamic_cast(resize_extent->definition()); ASSERT_TRUE(bop != nullptr) << "Unexpected resize output extent: " << resize_extent->toInlineString(); - EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Max) + EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Min) << "Unexpected resize output extent: " << resize_extent->toInlineString(); } From 5dc95f45d3aa0cf87610e5cd1ae254d0fb40f3f6 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 13 Feb 2025 18:03:59 -0800 Subject: [PATCH 16/32] cleanup --- csrc/ops/alias.cpp | 41 +++++++++++++-------------------------- tests/cpp/test_resize.cpp | 4 ++-- 2 files changed, 16 insertions(+), 29 deletions(-) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 46192b4fc2e..fd543e37df4 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -748,23 +748,26 @@ TensorView* slice( ndims); const auto get_int = [](Val* x) -> std::optional { - if (x->isConstInt()) { + if (x != nullptr && x->isConstInt()) { return x->evaluate().as(); } else { return std::nullopt; } }; - // Simplify patterns like min(min(x, 32), 32) to min(x, 32) as it - // isn't uncommon - const auto min_expr = [&](Val* x, Val* y) -> Val* { + // Specialized min for extents. Do some more simplification beyond + // SimplifyingIrBuilder that are only valid for extents. 
+ const auto min_extents = [&](Val* x, Val* y) -> Val* { auto x_int = get_int(x); auto y_int = get_int(y); + // Since extents are never negative, if one is 0, that must be the mininum. if (x_int == 0) { return x; } else if (y_int == 0) { return y; } + // Simplify patterns like min(min(x, 32), 32) to min(x, 32) as it + // isn't uncommon. auto bop = dynamic_cast(x->definition()); if (y_int != std::nullopt && bop != nullptr && bop->getBinaryOpType() == BinaryOpType::Min) { @@ -780,26 +783,16 @@ TensorView* slice( return SimplifyingIrBuilder::minExpr(x, y); }; - const auto normalize_slice_range = [&manual_normalization, &min_expr]( - Slice range, Val* extent) -> Slice { - std::optional extent_int; - if (extent->isConstInt()) { - extent_int = extent->evaluate().as(); - } - + const auto normalize_slice_range = + [&manual_normalization, &min_extents, &get_int]( + Slice range, Val* extent) -> Slice { auto cast_extent = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index); - std::optional start_int; - if (range.start->isConstInt()) { - start_int = range.start->evaluate().as(); - } - std::optional stop_int; - if (range.stop->isConstInt()) { - stop_int = range.start->evaluate().as(); - } + auto start_int = get_int(range.start); + auto stop_int = get_int(range.stop); // norm_start = max(0, start < 0 ? start + extent : start) if (range.start == nullptr) { @@ -821,10 +814,6 @@ TensorView* slice( } } - if (range.stop) { - std::cerr << "Normalizing stop from: " << range.stop->toString() - << ", cast extent: " << cast_extent->toInlineString() << "\n"; - } // norm_stop = max(norm_start, min(extent, stop < 0 ? 
stop + extent : stop) if (range.stop == nullptr) { range.stop = cast_extent; @@ -833,19 +822,17 @@ TensorView* slice( SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); // Commonly, range.start is zero and stop is non negative if (start_int == 0 && stop_int >= 0) { - range.stop = min_expr(cast_extent, range.stop); + range.stop = min_extents(cast_extent, range.stop); } else { if (!manual_normalization) { range.stop = SimplifyingIrBuilder::maxExpr( range.start, - min_expr( + min_extents( cast_extent, SimplifyingIrBuilder::whereExpr( SimplifyingIrBuilder::ltExpr(range.stop, zero), SimplifyingIrBuilder::addExpr(range.stop, cast_extent), range.stop))); - std::cerr << "Normalized to : " << range.stop->toInlineString() - << "\n"; } } } diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index e44ac915610..4301f4de3f5 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -1385,7 +1385,7 @@ TEST_F(ResizeTest, SliceExtentSimplification) { // By default, the extent of the tv1 domain is: // i0 + ( ( fmax(0, ( fmin(i0, 1) )) ) + ( -i0 ) ) // This should be simplified to just: - // fmax(0, ( fmin(i0, 1) )) + // fmin(i0, 1) fusion.addOutput(tv1); @@ -1393,7 +1393,7 @@ TEST_F(ResizeTest, SliceExtentSimplification) { auto bop = dynamic_cast(resize_extent->definition()); ASSERT_TRUE(bop != nullptr) << "Unexpected resize output extent: " << resize_extent->toInlineString(); - EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Max) + EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Min) << "Unexpected resize output extent: " << resize_extent->toInlineString(); } From b979fc58e0be68ccc1d50b5c3fd9a00a4c62d00c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 13 Feb 2025 22:42:17 -0800 Subject: [PATCH 17/32] cleanup --- tests/cpp/test_gpu3.cpp | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 43d3bc4ed19..69b11a5d69f 100644 --- a/tests/cpp/test_gpu3.cpp +++ 
b/tests/cpp/test_gpu3.cpp @@ -9342,28 +9342,6 @@ TEST_F(NVFuserTest, RegisteredExactMappingWithExtentReplacment) { } } -TEST_F(NVFuserTest, TMP) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {1}); - fusion.addOutput(tv1); - - tv1->split(1, 4); - auto rf_tv = tv1->rFactor({-1}); - std::cerr << "RF: " << rf_tv->toString() << "\n"; - - fusion.print(); - - ComputeAtMap ca_map(&fusion); - scheduler_utils::propagateReshapeTransforms(&fusion, ca_map); - - fusion.print(); -} - // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser From 339aac314ef840532b46887d2ce291a72d82c5a9 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 14 Feb 2025 15:30:39 -0800 Subject: [PATCH 18/32] WIP: fix --- csrc/predicate_compute.cpp | 164 +++++++++++++++++++++++++++++++++++-- csrc/val_graph_visitor.h | 16 ++++ 2 files changed, 171 insertions(+), 9 deletions(-) diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 1500e90a5d5..55c84396359 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -10,6 +10,8 @@ #include #include #include +#include +#include #include #include #include @@ -40,8 +42,17 @@ bool isOutputLocal(const Expr* expr) { } // namespace bool ParallelizedDomainPredicate::PredicateInfo::addDomain(IterDomain* id) { +#if 1 auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( id, IdMappingMode::EXACT); +#else + auto concrete_id = GpuLower::current() + ->idModel() + .idGraph(IdMappingMode::EXACT) + .toGroup(id) + ->front() + ->as(); +#endif if (std::find(ids_.begin(), ids_.end(), concrete_id) == ids_.end()) { ids_.push_back(concrete_id); return true; @@ -72,17 +83,17 @@ Val* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const { namespace { std::unordered_set getNonUnswitchedRootDomains( + const Expr* expr, const std::vector& loops, size_t unswitched_loop_index) { - std::vector 
non_unswited_loop_domains; + std::vector non_unswitched_loop_domains; std::transform( loops.begin(), loops.begin() + (int64_t)unswitched_loop_index, - std::back_inserter(non_unswited_loop_domains), + std::back_inserter(non_unswitched_loop_domains), [&](ForLoop* loop) { return loop->iter_domain(); }); - auto non_unswitched_inputs = - IterVisitor::getInputsTo(non_unswited_loop_domains); + IterVisitor::getInputsTo(non_unswitched_loop_domains); auto non_unswitched_root_doms = ir_utils::filterByType(non_unswitched_inputs); @@ -95,7 +106,7 @@ std::unordered_set getNonUnswitchedRootDomains( std::inserter( non_unswitched_concrete_root_domains, non_unswitched_concrete_root_domains.end()), - [&](auto root_dom) { + [&](IterDomain* root_dom) { return GpuLower::current()->caMap()->getConcreteMappedID( root_dom, IdMappingMode::EXACT); }); @@ -111,7 +122,7 @@ bool isFullyUnswitched( auto root_domains = ir_utils::filterByType(root_vals); return std::none_of( - root_domains.begin(), root_domains.end(), [&](auto root_dom) { + root_domains.begin(), root_domains.end(), [&](IterDomain* root_dom) { auto concrete_root_dom = GpuLower::current()->caMap()->getConcreteMappedID( root_dom, IdMappingMode::EXACT); @@ -119,6 +130,100 @@ bool isFullyUnswitched( }); } +std::vector getFullyUnswitchedLoopIds( + const Expr* expr, + const std::vector& loops, + ForLoop* unswitched_loop) { + if (unswitched_loop == nullptr) { + return {}; + } + + const auto& id_model = GpuLower::current()->idModel(); + const auto& indexing_graph = + id_model.idGraph(TensorIndexer::traversalGraphType()); + + auto out_tv = ir_utils::getTvOutput(expr); + NVF_ERROR(out_tv != nullptr); + + std::vector loop_ids; + loop_ids.reserve(loops.size()); + std::transform( + loops.begin(), + loops.end(), + std::back_inserter(loop_ids), + [&](ForLoop* loop) { return loop->iter_domain(); }); + + const auto predicate_ids = getPredicateDomains(out_tv, expr); + + const IndexingTraversal::ExprPath predicate_path = + 
IndexingTraversal::getExprsBetween( + expr, indexing_graph, loop_ids, predicate_ids); + + ValGroups non_unswitch_dep_ids; + std::vector unswitched_loop_ids; + bool unswitch_found = false; + for (const auto loop : loops) { + if (loop == unswitched_loop) { + unswitch_found = true; + } + if (unswitch_found) { + unswitched_loop_ids.push_back(loop->iter_domain()); + } else { + non_unswitch_dep_ids.pushBack( + indexing_graph.toGroup(loop->iter_domain())); + } + } + + for (const auto& [expr_g, dir] : predicate_path) { + const auto inputs = getInputsOfExprGroup(indexing_graph, expr_g, dir); + const auto outputs = getOutputsOfExprGroup(indexing_graph, expr_g, dir); + if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& input) { + return non_unswitch_dep_ids.has(input); + })) { + // Depends on non-unswitched ids + non_unswitch_dep_ids.pushBack(outputs); + } + } + + // If none of unswitched_loop_ids is used with the non-unswitched + // loop ids, + + std::vector fully_unswitched_loop_ids; + for (auto unswitched_loop_id : unswitched_loop_ids) { + if (!isParallelTypeThread(unswitched_loop_id->getParallelType())) { + continue; + } + + ValGroups unswitch_dep_ids; + unswitch_dep_ids.pushBack(indexing_graph.toGroup(unswitched_loop_id)); + + bool conflict_found = false; + for (const auto& [expr_g, dir] : predicate_path) { + const auto inputs = getInputsOfExprGroup(indexing_graph, expr_g, dir); + const auto outputs = getOutputsOfExprGroup(indexing_graph, expr_g, dir); + if (std::none_of( + inputs.begin(), inputs.end(), [&](const ValGroup& input) { + return unswitch_dep_ids.has(input); + })) { + continue; + } + + if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& input) { + return non_unswitch_dep_ids.has(input); + })) { + conflict_found = true; + break; + } + } + + if (!conflict_found) { + fully_unswitched_loop_ids.push_back(unswitched_loop_id); + } + } + + return fully_unswitched_loop_ids; +} + } // namespace std::unordered_map @@ -147,13 +252,19 @@ 
ParallelizedDomainPredicate::getPredicateMap( bool within_unswitch = false; std::unordered_set non_unswitched_root_domains; + auto fully_unswitched_loop_ids = + getFullyUnswitchedLoopIds(expr, loops, unswitched_loop); + for (const auto i : c10::irange(loops.size())) { auto loop = loops[i]; // Parallel dimensions need not be predicated if fully unswitched. if (loop == unswitched_loop) { within_unswitch = true; - non_unswitched_root_domains = getNonUnswitchedRootDomains(loops, i); +#if 0 + non_unswitched_root_domains = getNonUnswitchedRootDomains( + expr, loops, i); +#endif } auto loop_id = loop->iter_domain(); @@ -167,10 +278,20 @@ ParallelizedDomainPredicate::getPredicateMap( auto parallel_dim = gpu_lower->parallelDimensionMap().getRaw(loop_ptype); // Parallel dimensions need not be predicated if fully unswitched. +#if 0 if (within_unswitch && isFullyUnswitched(loop_id, non_unswitched_root_domains)) { continue; } +#else + if (within_unswitch && + std::find( + fully_unswitched_loop_ids.begin(), + fully_unswitched_loop_ids.end(), + loop_id) != fully_unswitched_loop_ids.end()) { + continue; + } +#endif for (auto tv : output_tvs) { // Check if the loop domain is used by the output tensor @@ -178,8 +299,15 @@ ParallelizedDomainPredicate::getPredicateMap( tv->getLoopDomain().begin(), tv->getLoopDomain().end(), [&](auto tv_id) { +#if 0 return gpu_lower->caMap()->areMapped( loop_id, tv_id, IdMappingMode::EXACT); +#else + return gpu_lower->idModel() + .idGraph(IdMappingMode::EXACT) + .disjointValSets() + .strictAreMapped(loop_id, tv_id); +#endif }); if (it == tv->getLoopDomain().end()) { continue; @@ -338,10 +466,19 @@ UnswitchPredicateKey::UnswitchPredicateKey( } // Find the corresponding concrete id for each parallel type - for (auto consumer_loop : parallelized_consumer_loop_ids) { + for (IterDomain* consumer_loop : parallelized_consumer_loop_ids) { auto pt = consumer_loop->getParallelType(); +#if 1 auto concrete_loop = GpuLower::current()->caMap()->getConcreteMappedID( 
consumer_loop, IdMappingMode::EXACT); +#else + auto concrete_loop = GpuLower::current() + ->idModel() + .idGraph(IdMappingMode::EXACT) + .toGroup(consumer_loop) + ->front() + ->as(); +#endif parallel_concrete_ids_.at(pt) = concrete_loop; } } @@ -728,9 +865,18 @@ void UnswitchPredicate::predicateOn(Expr* tv_expr) { UnswitchPredicateKey first_key; bool first_key_set = false; - for (auto root_id : root_ids) { + for (IterDomain* root_id : root_ids) { +#if 1 auto concrete_root_id = gpu_lower->caMap()->getConcreteMappedID( root_id, IdMappingMode::EXACT); +#else + auto concrete_root_id = GpuLower::current() + ->idModel() + .idGraph(IdMappingMode::EXACT) + .toGroup(root_id) + ->front() + ->as(); +#endif if (root_id->isBroadcast()) { continue; diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h index c612948a67b..40e1a9b14bf 100644 --- a/csrc/val_graph_visitor.h +++ b/csrc/val_graph_visitor.h @@ -276,4 +276,20 @@ class ValGraphPermissiveBFS : public BFSWithPermissiveDependence< } }; +inline std::vector getInputsOfExprGroup( + const ValGraph& graph, + const ExprGroup& expr, + Direction dir) { + return getInputsOfExpr( + expr, dir, ValGraphInputs(graph), ValGraphOutputs(graph)); +} + +inline std::vector getOutputsOfExprGroup( + const ValGraph& graph, + const ExprGroup& expr, + Direction dir) { + return getOutputsOfExpr( + expr, dir, ValGraphInputs(graph), ValGraphOutputs(graph)); +} + } // namespace nvfuser From a08ce8ac08f317f87b32785d1f7d6f137a9e62f0 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 14 Feb 2025 19:04:02 -0800 Subject: [PATCH 19/32] debug --- csrc/device_lower/lower2device.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 75d9352c05b..f5a89442189 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -498,6 +498,9 @@ void GpuLower::analysis(Fusion* fusion) { /*allow_self_mapping=*/false, /*validate=*/false); 
id_model_->validateAndPropagatePType(); + + id_model_->idGraph(IdMappingMode::ALMOSTEXACT).dumpGraphvizDotGraph("almost_exact.dot"); + fusion->print(); } resolveComputeWith(fusion_); From 1d50e73c45097d24b71c6d5c4c4a1cb2d7381d34 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 14 Feb 2025 19:33:57 -0800 Subject: [PATCH 20/32] fix --- csrc/predicate_compute.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 55c84396359..dccc083d291 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -151,7 +151,16 @@ std::vector getFullyUnswitchedLoopIds( loops.begin(), loops.end(), std::back_inserter(loop_ids), - [&](ForLoop* loop) { return loop->iter_domain(); }); + [&](ForLoop* loop) { + const auto& loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(loop->iter_domain()); + auto promotion_it = id_model.loopPromotionMap().find(loop_group); + NVF_ERROR( + promotion_it != id_model.loopPromotionMap().end(), + "Loop promotion not found for ", + loop->iter_domain()->toString()); + return promotion_it->second; + }); const auto predicate_ids = getPredicateDomains(out_tv, expr); From d72668ed3ba2b6366b840b57df7e4382335a9e20 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sun, 16 Feb 2025 10:59:37 -0800 Subject: [PATCH 21/32] clang-tidy --- csrc/ir/builder.cpp | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index 5121a6104ce..756b0e1b98c 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -381,14 +381,18 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) { return addExpr(lhs, rhs->value(), rhs->dtype()); } else { // Simplify x + (-x) to 0 - if (auto neg_expr = dynamic_cast(lhs->definition()); - neg_expr != nullptr && neg_expr->getUnaryOpType() == UnaryOpType::Neg && - neg_expr->in()->sameAs(rhs)) { - return lhs->fusion()->zeroVal(lhs->dtype()); - } else if 
(auto neg_expr = dynamic_cast(rhs->definition()); - neg_expr != nullptr && - neg_expr->getUnaryOpType() == UnaryOpType::Neg && - neg_expr->in()->sameAs(lhs)) { + Val* x = nullptr; + auto uop = dynamic_cast(lhs->definition()); + if (uop != nullptr) { + // lhs may be (-x). Pick rhs as x + x = rhs; + } else { + uop = dynamic_cast(rhs->definition()); + // rhs may be (-x). Pick lhs as x + x = lhs; + } + if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg && + uop->in()->sameAs(x)) { return lhs->fusion()->zeroVal(lhs->dtype()); } return IrBuilder::addExpr(lhs, rhs); From d9fa74e66706e5e373008a48de940e69f839d68f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sun, 16 Feb 2025 17:42:30 -0800 Subject: [PATCH 22/32] remove debug print --- csrc/device_lower/lower2device.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index f5a89442189..75d9352c05b 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -498,9 +498,6 @@ void GpuLower::analysis(Fusion* fusion) { /*allow_self_mapping=*/false, /*validate=*/false); id_model_->validateAndPropagatePType(); - - id_model_->idGraph(IdMappingMode::ALMOSTEXACT).dumpGraphvizDotGraph("almost_exact.dot"); - fusion->print(); } resolveComputeWith(fusion_); From 9f6a79f48c378a49c9cd111fa2ac3ecd606b4fde Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 19 Feb 2025 18:25:03 -0800 Subject: [PATCH 23/32] python test WAR --- tests/python/test_python_frontend.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index e2170b4d16f..3acc21028a5 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -3276,6 +3276,16 @@ def fusion_func(fd: FusionDefinition) -> None: nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, supports_segmentation=False) # self.assertEqual(nvf_out[0], t24) + # This 
fusion takes a long time to segment and schedule + # because of the resized extents, which seem to stress the + # expression simplifier a lot. Serializing this fusion would + # significantly increase the test time as it would be + # deserialized every time, which includes segmentation and + # scheduling. Ideally, we should optimize the expression + # simplifier, but for now resetting the cache should avoid the + # issue. + FusionCache.reset() + # Test that symbolic IterDomains can be concatenated # https://github.com/NVIDIA/Fuser/issues/1554 def test_cat_symbolic(self): From 2a56bbd955df7f2b28d213725e3deb911af2a3b0 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 20 Feb 2025 19:57:48 -0800 Subject: [PATCH 24/32] cleanup --- csrc/predicate_compute.cpp | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index dccc083d291..08bd337ab61 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -42,17 +42,8 @@ bool isOutputLocal(const Expr* expr) { } // namespace bool ParallelizedDomainPredicate::PredicateInfo::addDomain(IterDomain* id) { -#if 1 auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID( id, IdMappingMode::EXACT); -#else - auto concrete_id = GpuLower::current() - ->idModel() - .idGraph(IdMappingMode::EXACT) - .toGroup(id) - ->front() - ->as(); -#endif if (std::find(ids_.begin(), ids_.end(), concrete_id) == ids_.end()) { ids_.push_back(concrete_id); return true; @@ -270,10 +261,6 @@ ParallelizedDomainPredicate::getPredicateMap( // Parallel dimensions need not be predicated if fully unswitched. 
if (loop == unswitched_loop) { within_unswitch = true; -#if 0 - non_unswitched_root_domains = getNonUnswitchedRootDomains( - expr, loops, i); -#endif } auto loop_id = loop->iter_domain(); @@ -287,12 +274,6 @@ ParallelizedDomainPredicate::getPredicateMap( auto parallel_dim = gpu_lower->parallelDimensionMap().getRaw(loop_ptype); // Parallel dimensions need not be predicated if fully unswitched. -#if 0 - if (within_unswitch && - isFullyUnswitched(loop_id, non_unswitched_root_domains)) { - continue; - } -#else if (within_unswitch && std::find( fully_unswitched_loop_ids.begin(), @@ -300,7 +281,6 @@ ParallelizedDomainPredicate::getPredicateMap( loop_id) != fully_unswitched_loop_ids.end()) { continue; } -#endif for (auto tv : output_tvs) { // Check if the loop domain is used by the output tensor @@ -477,17 +457,8 @@ UnswitchPredicateKey::UnswitchPredicateKey( // Find the corresponding concrete id for each parallel type for (IterDomain* consumer_loop : parallelized_consumer_loop_ids) { auto pt = consumer_loop->getParallelType(); -#if 1 auto concrete_loop = GpuLower::current()->caMap()->getConcreteMappedID( consumer_loop, IdMappingMode::EXACT); -#else - auto concrete_loop = GpuLower::current() - ->idModel() - .idGraph(IdMappingMode::EXACT) - .toGroup(consumer_loop) - ->front() - ->as(); -#endif parallel_concrete_ids_.at(pt) = concrete_loop; } } @@ -875,17 +846,8 @@ void UnswitchPredicate::predicateOn(Expr* tv_expr) { bool first_key_set = false; for (IterDomain* root_id : root_ids) { -#if 1 auto concrete_root_id = gpu_lower->caMap()->getConcreteMappedID( root_id, IdMappingMode::EXACT); -#else - auto concrete_root_id = GpuLower::current() - ->idModel() - .idGraph(IdMappingMode::EXACT) - .toGroup(root_id) - ->front() - ->as(); -#endif if (root_id->isBroadcast()) { continue; From 90e13e3f15a0f88b72a59c95392334c4d1b08974 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 20 Feb 2025 20:02:44 -0800 Subject: [PATCH 25/32] cleanup --- csrc/predicate_compute.cpp | 7 ------- 1 
file changed, 7 deletions(-) diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 08bd337ab61..bd45108c3e5 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -288,15 +288,8 @@ ParallelizedDomainPredicate::getPredicateMap( tv->getLoopDomain().begin(), tv->getLoopDomain().end(), [&](auto tv_id) { -#if 0 return gpu_lower->caMap()->areMapped( loop_id, tv_id, IdMappingMode::EXACT); -#else - return gpu_lower->idModel() - .idGraph(IdMappingMode::EXACT) - .disjointValSets() - .strictAreMapped(loop_id, tv_id); -#endif }); if (it == tv->getLoopDomain().end()) { continue; From 0c0d9b8535f89e561827a1765481d6de306c8463 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 24 Feb 2025 11:41:27 -0800 Subject: [PATCH 26/32] cleanup --- csrc/scheduler/tools/loop_domain_scheduler.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index 9a2e8a14764..2859decd6e7 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -111,8 +111,5 @@ void cancelReshapeInLoopDomains( TensorView* from_tv, bool skip_innermost_id = false); -std::optional getInnermostCancelableReshapePosition( - TensorView* from_tv); - } // namespace scheduler_tools } // namespace nvfuser From 2315b48e1c0fd94719f1b7f419b577a0460d06a4 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 28 Feb 2025 07:48:48 -0800 Subject: [PATCH 27/32] cleanup --- csrc/predicate_compute.cpp | 291 +++++-------------------------------- 1 file changed, 38 insertions(+), 253 deletions(-) diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 8385548b92b..1500e90a5d5 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -10,8 +10,6 @@ #include #include #include -#include -#include #include #include #include @@ -73,249 +71,52 @@ Val* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const { 
namespace { -// For a given loop nest represented by a vector of ForLoops, returns -// all unswitched parallel loop IDs that do not require parallel type -// predicates. An ID is considered fully unswitched when all of its -// dependent loop IDs are unswitched. Similarly, a loop is fully -// unswitched when all of its dependent predicated IDs are fully -// unswitched. This information is used to determine if it's safe to -// omit the predicate for a parallel type. -std::vector getUnswitchProtectedParallelLoopIds( - const Expr* expr, +std::unordered_set getNonUnswitchedRootDomains( const std::vector& loops, - ForLoop* unswitched_loop) { - if (unswitched_loop == nullptr) { - return {}; - } - - const auto& id_model = GpuLower::current()->idModel(); - const auto& indexing_graph = - id_model.idGraph(TensorIndexer::traversalGraphType()); - - auto out_tv = ir_utils::getTvOutput(expr); - NVF_ERROR(out_tv != nullptr); - - std::vector loop_ids; - loop_ids.reserve(loops.size()); + size_t unswitched_loop_index) { + std::vector non_unswited_loop_domains; std::transform( loops.begin(), - loops.end(), - std::back_inserter(loop_ids), - [&](ForLoop* loop) { - return getLoopPromotion(loop->iter_domain(), id_model); - }); - - const auto predicate_ids = getPredicateDomains(out_tv, expr); - - const IndexingTraversal::ExprPath predicate_path = - IndexingTraversal::getExprsBetween( - expr, indexing_graph, loop_ids, predicate_ids); - - // All loops that are right of unswitched_loop are also unswitched, - // except when they are parallelized. We don't assign maximum possible - // index values to unswitched parallel loops (e.g., threadIdx.x, not - // blockDim.x - 1), so parallelized loops are not considered - // unswitched for the sake of this analysis. 
- ValGroups non_unswitch_dep_ids; - bool unswitch_found = false; - for (const auto loop : loops) { - if (loop == unswitched_loop) { - unswitch_found = true; - } - if (!unswitch_found || - isParallelTypeThread(loop->iter_domain()->getParallelType())) { - non_unswitch_dep_ids.pushBack( - indexing_graph.toGroup(loop->iter_domain())); - } - } - - // Find all IDs along the predicate indexing path that depend on the - // non unswitched loop IDs. - for (const auto& [expr_g, dir] : predicate_path) { - const auto inputs = getInputsOfExprGroup(indexing_graph, expr_g, dir); - const auto outputs = getOutputsOfExprGroup(indexing_graph, expr_g, dir); - if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& input) { - return non_unswitch_dep_ids.has(input); - })) { - // Depends on non-unswitched ids - non_unswitch_dep_ids.pushBack(outputs); - } - } + loops.begin() + (int64_t)unswitched_loop_index, + std::back_inserter(non_unswited_loop_domains), + [&](ForLoop* loop) { return loop->iter_domain(); }); - std::vector unswitch_protected_loop_ids; - unswitch_found = false; - for (const auto loop : loops) { - if (loop == unswitched_loop) { - unswitch_found = true; - } - - if (!unswitch_found) { - continue; - } + auto non_unswitched_inputs = + IterVisitor::getInputsTo(non_unswited_loop_domains); - const auto unswitched_loop_id = loop->iter_domain(); - const ParallelType pt = unswitched_loop_id->getParallelType(); + auto non_unswitched_root_doms = + ir_utils::filterByType(non_unswitched_inputs); - // Don't care serial loops - if (!isParallelTypeThread(pt)) { - continue; - } - - // Traverse the predicate indexing path from this unswitched loop - // ID. If any expr along the path also uses any of the non - // unswitched IDs or their dependent IDs, this loop ID is not - // considered fully unswitched. 
Also, even if unswitched, - // parallelized loop IDs do not use the maximum possible value as - // their indices (e.g., not (blockDim.x - 1) but threadIdx.x), so - // there must be no use of any of other parallel types than this - // parallel type. - - // Keep track of IDs that have dependencies with unswitched_loop_id - ValGroups unswitch_dep_ids; - unswitch_dep_ids.pushBack(indexing_graph.toGroup(unswitched_loop_id)); - - bool protected_by_unswitch = true; - - for (const auto& [expr_g, dir] : predicate_path) { - const auto inputs = getInputsOfExprGroup(indexing_graph, expr_g, dir); - const auto outputs = getOutputsOfExprGroup(indexing_graph, expr_g, dir); - - // If none of the inputs depends on unswitched_loop_id and its - // dependents, this expr should not matter. - if (std::none_of( - inputs.begin(), inputs.end(), [&](const ValGroup& input) { - return unswitch_dep_ids.has(input); - })) { - continue; - } - - // If any of the non unswitched IDs is used, this is not - // protected. Note that non_unswitch_dep_ids contains all - // parallelized unswitched IDs and their dependents, including - // unswitched_loop_id itself. Use of unswitched_loop_id and its - // dependents should not make unswitched_loop_id not fully - // unswitched. 
- if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& input) { - return non_unswitch_dep_ids.has(input) && - !unswitch_dep_ids.has(input); - })) { - protected_by_unswitch = false; - break; - } - - // Continue to keep track of the dependencies from unswitched_loop_id - unswitch_dep_ids.pushBack(outputs); - } - - if (protected_by_unswitch) { - unswitch_protected_loop_ids.push_back(unswitched_loop_id); - } - } - - return unswitch_protected_loop_ids; -} + std::unordered_set non_unswitched_concrete_root_domains; -std::vector getFullyUnswitchedLoopIds( - const Expr* expr, - const std::vector& loops, - ForLoop* unswitched_loop) { - if (unswitched_loop == nullptr) { - return {}; - } - - const auto& id_model = GpuLower::current()->idModel(); - const auto& indexing_graph = - id_model.idGraph(TensorIndexer::traversalGraphType()); - - auto out_tv = ir_utils::getTvOutput(expr); - NVF_ERROR(out_tv != nullptr); - - std::vector loop_ids; - loop_ids.reserve(loops.size()); std::transform( - loops.begin(), - loops.end(), - std::back_inserter(loop_ids), - [&](ForLoop* loop) { - const auto& loop_group = - id_model.idGraph(IdMappingMode::LOOP).toGroup(loop->iter_domain()); - auto promotion_it = id_model.loopPromotionMap().find(loop_group); - NVF_ERROR( - promotion_it != id_model.loopPromotionMap().end(), - "Loop promotion not found for ", - loop->iter_domain()->toString()); - return promotion_it->second; + non_unswitched_root_doms.begin(), + non_unswitched_root_doms.end(), + std::inserter( + non_unswitched_concrete_root_domains, + non_unswitched_concrete_root_domains.end()), + [&](auto root_dom) { + return GpuLower::current()->caMap()->getConcreteMappedID( + root_dom, IdMappingMode::EXACT); }); - const auto predicate_ids = getPredicateDomains(out_tv, expr); - - const IndexingTraversal::ExprPath predicate_path = - IndexingTraversal::getExprsBetween( - expr, indexing_graph, loop_ids, predicate_ids); - - ValGroups non_unswitch_dep_ids; - std::vector unswitched_loop_ids; - 
bool unswitch_found = false; - for (const auto loop : loops) { - if (loop == unswitched_loop) { - unswitch_found = true; - } - if (unswitch_found) { - unswitched_loop_ids.push_back(loop->iter_domain()); - } else { - non_unswitch_dep_ids.pushBack( - indexing_graph.toGroup(loop->iter_domain())); - } - } - - for (const auto& [expr_g, dir] : predicate_path) { - const auto inputs = getInputsOfExprGroup(indexing_graph, expr_g, dir); - const auto outputs = getOutputsOfExprGroup(indexing_graph, expr_g, dir); - if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& input) { - return non_unswitch_dep_ids.has(input); - })) { - // Depends on non-unswitched ids - non_unswitch_dep_ids.pushBack(outputs); - } - } - - // If none of unswitched_loop_ids is used with the non-unswitched - // loop ids, - - std::vector fully_unswitched_loop_ids; - for (auto unswitched_loop_id : unswitched_loop_ids) { - if (!isParallelTypeThread(unswitched_loop_id->getParallelType())) { - continue; - } - - ValGroups unswitch_dep_ids; - unswitch_dep_ids.pushBack(indexing_graph.toGroup(unswitched_loop_id)); - - bool conflict_found = false; - for (const auto& [expr_g, dir] : predicate_path) { - const auto inputs = getInputsOfExprGroup(indexing_graph, expr_g, dir); - const auto outputs = getOutputsOfExprGroup(indexing_graph, expr_g, dir); - if (std::none_of( - inputs.begin(), inputs.end(), [&](const ValGroup& input) { - return unswitch_dep_ids.has(input); - })) { - continue; - } + return non_unswitched_concrete_root_domains; +} - if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& input) { - return non_unswitch_dep_ids.has(input); - })) { - conflict_found = true; - break; - } - } +bool isFullyUnswitched( + IterDomain* loop_id, + const std::unordered_set& non_unswitched_root_domains) { + auto root_vals = IterVisitor::getInputsTo({loop_id}); - if (!conflict_found) { - fully_unswitched_loop_ids.push_back(unswitched_loop_id); - } - } + auto root_domains = 
ir_utils::filterByType(root_vals); - return fully_unswitched_loop_ids; + return std::none_of( + root_domains.begin(), root_domains.end(), [&](auto root_dom) { + auto concrete_root_dom = + GpuLower::current()->caMap()->getConcreteMappedID( + root_dom, IdMappingMode::EXACT); + return non_unswitched_root_domains.count(concrete_root_dom) > 0; + }); } } // namespace @@ -346,15 +147,13 @@ ParallelizedDomainPredicate::getPredicateMap( bool within_unswitch = false; std::unordered_set non_unswitched_root_domains; - auto unswitch_protected_loop_ids = - getUnswitchProtectedParallelLoopIds(expr, loops, unswitched_loop); - for (const auto i : c10::irange(loops.size())) { auto loop = loops[i]; // Parallel dimensions need not be predicated if fully unswitched. if (loop == unswitched_loop) { within_unswitch = true; + non_unswitched_root_domains = getNonUnswitchedRootDomains(loops, i); } auto loop_id = loop->iter_domain(); @@ -367,23 +166,9 @@ ParallelizedDomainPredicate::getPredicateMap( } auto parallel_dim = gpu_lower->parallelDimensionMap().getRaw(loop_ptype); - // If protected by unswitch, the unswitch predicate is enough without - // predicating the parallel type. For example, suppose a logical - // ID is inner split by a factor of K and both of the two outputs - // are unswitched. Also suppose the inner output IDs is - // parallelized with TIDx but the other output is not. The logical - // ID would be predicated by something like: - // - // threadIdx.x + (ceilDiv(N, K) - 1) * K < N - // - // where N is the extent of the logical ID. As you can see, since - // the other output is assigned with the maximum index, this - // predicate is sufficient even when blockDim.x > K. + // Parallel dimensions need not be predicated if fully unswitched. 
if (within_unswitch && - std::find( - unswitch_protected_loop_ids.begin(), - unswitch_protected_loop_ids.end(), - loop_id) != unswitch_protected_loop_ids.end()) { + isFullyUnswitched(loop_id, non_unswitched_root_domains)) { continue; } @@ -553,7 +338,7 @@ UnswitchPredicateKey::UnswitchPredicateKey( } // Find the corresponding concrete id for each parallel type - for (IterDomain* consumer_loop : parallelized_consumer_loop_ids) { + for (auto consumer_loop : parallelized_consumer_loop_ids) { auto pt = consumer_loop->getParallelType(); auto concrete_loop = GpuLower::current()->caMap()->getConcreteMappedID( consumer_loop, IdMappingMode::EXACT); @@ -943,7 +728,7 @@ void UnswitchPredicate::predicateOn(Expr* tv_expr) { UnswitchPredicateKey first_key; bool first_key_set = false; - for (IterDomain* root_id : root_ids) { + for (auto root_id : root_ids) { auto concrete_root_id = gpu_lower->caMap()->getConcreteMappedID( root_id, IdMappingMode::EXACT); From 015c730f3132fafaf5de86ff83c979625c1de22e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 28 Feb 2025 14:54:19 -0800 Subject: [PATCH 28/32] update --- csrc/predicate_compute.cpp | 188 +++++++++++++++++++++++++++++-------- 1 file changed, 150 insertions(+), 38 deletions(-) diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 1500e90a5d5..9c8ae86a880 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -10,6 +10,8 @@ #include #include #include +#include +#include #include #include #include @@ -71,52 +73,146 @@ Val* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const { namespace { -std::unordered_set getNonUnswitchedRootDomains( +// For a given loop nest represented by a vector of ForLoops, returns +// all unswitched parallel loop IDs that do not require parallel type +// predicates. An ID is considered fully unswitched when all of its +// dependent loop IDs are unswitched. 
Similarly, a loop is fully +// unswitched when all of its dependent predicated IDs are fully +// unswitched. This information is used to determine if it's safe to +// omit the predicate for a parallel type. +std::vector getUnswitchProtectedParallelLoopIds( + const Expr* expr, const std::vector& loops, - size_t unswitched_loop_index) { - std::vector non_unswited_loop_domains; + ForLoop* unswitched_loop) { + if (unswitched_loop == nullptr) { + return {}; + } + + const auto& id_model = GpuLower::current()->idModel(); + const auto& indexing_graph = + id_model.idGraph(TensorIndexer::traversalGraphType()); + + auto out_tv = ir_utils::getTvOutput(expr); + NVF_ERROR(out_tv != nullptr); + + std::vector loop_ids; + loop_ids.reserve(loops.size()); std::transform( loops.begin(), - loops.begin() + (int64_t)unswitched_loop_index, - std::back_inserter(non_unswited_loop_domains), - [&](ForLoop* loop) { return loop->iter_domain(); }); + loops.end(), + std::back_inserter(loop_ids), + [&](ForLoop* loop) { + return getLoopPromotion(loop->iter_domain(), id_model); + }); - auto non_unswitched_inputs = - IterVisitor::getInputsTo(non_unswited_loop_domains); + const auto predicate_ids = getPredicateDomains(out_tv, expr); - auto non_unswitched_root_doms = - ir_utils::filterByType(non_unswitched_inputs); + const IndexingTraversal::ExprPath predicate_path = + IndexingTraversal::getExprsBetween( + expr, indexing_graph, loop_ids, predicate_ids); - std::unordered_set non_unswitched_concrete_root_domains; + // All loops that are right of unswitched_loop are also unswitched, + // except when they are parallelized. We don't assign maximum possible + // index values to unswitched parallel loops (e.g., threadIdx.x, not + // blockDim.x - 1), so parallelized loops are not considered + // unswitched for the sake of this analysis. 
+ ValGroups non_unswitch_dep_ids; + bool unswitch_found = false; + for (const auto loop : loops) { + if (loop == unswitched_loop) { + unswitch_found = true; + } + if (!unswitch_found || + isParallelTypeThread(loop->iter_domain()->getParallelType())) { + non_unswitch_dep_ids.pushBack( + indexing_graph.toGroup(loop->iter_domain())); + } + } - std::transform( - non_unswitched_root_doms.begin(), - non_unswitched_root_doms.end(), - std::inserter( - non_unswitched_concrete_root_domains, - non_unswitched_concrete_root_domains.end()), - [&](auto root_dom) { - return GpuLower::current()->caMap()->getConcreteMappedID( - root_dom, IdMappingMode::EXACT); - }); + // Find all IDs along the predicate indexing path that depend on the + // non unswitched loop IDs. + for (const auto& [expr_g, dir] : predicate_path) { + const auto inputs = getInputsOfExprGroup(indexing_graph, expr_g, dir); + const auto outputs = getOutputsOfExprGroup(indexing_graph, expr_g, dir); + if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& input) { + return non_unswitch_dep_ids.has(input); + })) { + // Depends on non-unswitched ids + non_unswitch_dep_ids.pushBack(outputs); + } + } - return non_unswitched_concrete_root_domains; -} + std::vector unswitch_protected_loop_ids; + unswitch_found = false; + for (const auto loop : loops) { + if (loop == unswitched_loop) { + unswitch_found = true; + } -bool isFullyUnswitched( - IterDomain* loop_id, - const std::unordered_set& non_unswitched_root_domains) { - auto root_vals = IterVisitor::getInputsTo({loop_id}); + if (!unswitch_found) { + continue; + } - auto root_domains = ir_utils::filterByType(root_vals); + const auto unswitched_loop_id = loop->iter_domain(); + const ParallelType pt = unswitched_loop_id->getParallelType(); - return std::none_of( - root_domains.begin(), root_domains.end(), [&](auto root_dom) { - auto concrete_root_dom = - GpuLower::current()->caMap()->getConcreteMappedID( - root_dom, IdMappingMode::EXACT); - return 
non_unswitched_root_domains.count(concrete_root_dom) > 0; - }); + // Don't care serial loops + if (!isParallelTypeThread(pt)) { + continue; + } + + // Traverse the predicate indexing path from this unswitched loop + // ID. If any expr along the path also uses any of the non + // unswitched IDs or their dependent IDs, this loop ID is not + // considered fully unswitched. Also, even if unswitched, + // parallelized loop IDs do not use the maximum possible value as + // their indices (e.g., not (blockDim.x - 1) but threadIdx.x), so + // there must be no use of any of other parallel types than this + // parallel type. + + // Keep track of IDs that have dependencies with unswitched_loop_id + ValGroups unswitch_dep_ids; + unswitch_dep_ids.pushBack(indexing_graph.toGroup(unswitched_loop_id)); + + bool protected_by_unswitch = true; + + for (const auto& [expr_g, dir] : predicate_path) { + const auto inputs = getInputsOfExprGroup(indexing_graph, expr_g, dir); + const auto outputs = getOutputsOfExprGroup(indexing_graph, expr_g, dir); + + // If none of the inputs depends on unswitched_loop_id and its + // dependents, this expr should not matter. + if (std::none_of( + inputs.begin(), inputs.end(), [&](const ValGroup& input) { + return unswitch_dep_ids.has(input); + })) { + continue; + } + + // If any of the non unswitched IDs is used, this is not + // protected. Note that non_unswitch_dep_ids contains all + // parallelized unswitched IDs and their dependents, including + // unswitched_loop_id itself. Use of unswitched_loop_id and its + // dependents should not make unswitched_loop_id not fully + // unswitched. 
+ if (std::any_of(inputs.begin(), inputs.end(), [&](const ValGroup& input) { + return non_unswitch_dep_ids.has(input) && + !unswitch_dep_ids.has(input); + })) { + protected_by_unswitch = false; + break; + } + + // Continue to keep track of the dependencies from unswitched_loop_id + unswitch_dep_ids.pushBack(outputs); + } + + if (protected_by_unswitch) { + unswitch_protected_loop_ids.push_back(unswitched_loop_id); + } + } + + return unswitch_protected_loop_ids; } } // namespace @@ -147,13 +243,15 @@ ParallelizedDomainPredicate::getPredicateMap( bool within_unswitch = false; std::unordered_set non_unswitched_root_domains; + auto unswitch_protected_loop_ids = + getUnswitchProtectedParallelLoopIds(expr, loops, unswitched_loop); + for (const auto i : c10::irange(loops.size())) { auto loop = loops[i]; // Parallel dimensions need not be predicated if fully unswitched. if (loop == unswitched_loop) { within_unswitch = true; - non_unswitched_root_domains = getNonUnswitchedRootDomains(loops, i); } auto loop_id = loop->iter_domain(); @@ -166,9 +264,23 @@ ParallelizedDomainPredicate::getPredicateMap( } auto parallel_dim = gpu_lower->parallelDimensionMap().getRaw(loop_ptype); - // Parallel dimensions need not be predicated if fully unswitched. + // If protected by unswitch, the unswitch predicate is enough without + // predicating the parallel type. For example, suppose a logical + // ID is inner split by a factor of K and both of the two outputs + // are unswitched. Also suppose the inner output IDs is + // parallelized with TIDx but the other output is not. The logical + // ID would be predicated by something like: + // + // threadIdx.x + (ceilDiv(N, K) - 1) * K < N + // + // where N is the extent of the logical ID. As you can see, since + // the other output is assigned with the maximum index, this + // predicate is sufficient even when blockDim.x > K. 
if (within_unswitch && - isFullyUnswitched(loop_id, non_unswitched_root_domains)) { + std::find( + unswitch_protected_loop_ids.begin(), + unswitch_protected_loop_ids.end(), + loop_id) != unswitch_protected_loop_ids.end()) { continue; } From ea5bc585f4a2891cd91293fd24616d9496aff386 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 3 Mar 2025 11:12:47 -0800 Subject: [PATCH 29/32] cleanup --- csrc/scheduler/resize.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index ff4cfa76e4e..2ba6dca051a 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -471,11 +471,8 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { } if (vec_factor > 1) { - // auto vec_ref_tv = largest_input != nullptr ? largest_input : - // ref_tv; - auto vec_ref_tv = ref_tv; const auto tvs_to_vectorize = - scheduler_utils::getInputsOutputsWithInnerDim(vec_ref_tv, true, true); + scheduler_utils::getInputsOutputsWithInnerDim(ref_tv, true, true); for (auto tv_to_vectorize : tvs_to_vectorize) { if (tv_to_vectorize->isFusionInput()) { for (auto consumer_tv : ir_utils::consumerTvsOf(tv_to_vectorize)) { From c1c33b07b2b8cf60e97ac558e5137e03944d26bc Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 3 Mar 2025 11:15:50 -0800 Subject: [PATCH 30/32] temporarily move back assertions --- csrc/scheduler/tools/loop_domain_scheduler.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index c79ad79245e..ccab7fce0e1 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -528,10 +528,12 @@ void scheduleLoopDomainsBy( Direction replay_dir_tv = Direction::Undefined; if (replay_dir != Direction::Backward && input_ids.size() == transform->inputs().size()) { + NVF_ERROR(output_ids.empty()); replay_dir_tv = Direction::Forward; } 
else if ( replay_dir != Direction::Forward && output_ids.size() == transform->outputs().size()) { + NVF_ERROR(input_ids.empty()); replay_dir_tv = Direction::Backward; } else { // Replay not possible since none of inputs nor outputs are connected with From 0c5c6f86200ca01d17dcc0f5a10866aa8eab42bc Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 4 Mar 2025 02:17:12 -0800 Subject: [PATCH 31/32] Remove assertions --- csrc/scheduler/tools/loop_domain_scheduler.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index ccab7fce0e1..c79ad79245e 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -528,12 +528,10 @@ void scheduleLoopDomainsBy( Direction replay_dir_tv = Direction::Undefined; if (replay_dir != Direction::Backward && input_ids.size() == transform->inputs().size()) { - NVF_ERROR(output_ids.empty()); replay_dir_tv = Direction::Forward; } else if ( replay_dir != Direction::Forward && output_ids.size() == transform->outputs().size()) { - NVF_ERROR(input_ids.empty()); replay_dir_tv = Direction::Backward; } else { // Replay not possible since none of inputs nor outputs are connected with From 875a92da6c7070d1c76c5d9378d7b889d130e5c4 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 5 Mar 2025 14:43:41 -0800 Subject: [PATCH 32/32] update --- csrc/scheduler/tools/loop_domain_scheduler.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index c79ad79245e..a2dfac7d4fa 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -522,9 +522,8 @@ void scheduleLoopDomainsBy( } } - // It should be either: all of the inputs found and none of the - // outputs found, or none of the inputs found and all of the - // outputs found. 
+ // If all of the inputs are found, the transform expr is replayed as
+  // a forward op.
   Direction replay_dir_tv = Direction::Undefined;
   if (replay_dir != Direction::Backward &&
       input_ids.size() == transform->inputs().size()) {