From ad6cec1bb5595a37baa723c0792775beeeece8c2 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 24 Feb 2025 14:47:21 -0800 Subject: [PATCH 1/5] WIP --- csrc/graph_traversal_bak.h | 307 ++++++++++++++++++ csrc/scheduler/resize.cpp | 7 +- .../scheduler/tools/loop_domain_scheduler.cpp | 27 +- csrc/scheduler/tools/loop_domain_scheduler.h | 8 +- tests/cpp/test_resize.cpp | 69 ++++ 5 files changed, 413 insertions(+), 5 deletions(-) create mode 100644 csrc/graph_traversal_bak.h diff --git a/csrc/graph_traversal_bak.h b/csrc/graph_traversal_bak.h new file mode 100644 index 00000000000..0fd483942f4 --- /dev/null +++ b/csrc/graph_traversal_bak.h @@ -0,0 +1,307 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +namespace nvfuser { + +template < + typename ExprT, + typename ValT, + typename DefinitionT, + typename UsesT, + typename InputsT, + typename OutputsT> +class FindAllPaths { + public: + using ExprType = ExprT; + using ValType = ValT; + using NodeType = std::variant; + using ExprPath = std::vector>; + using InputsType = InputsT; + using OutputsType = OutputsT; + + FindAllPaths(DefinitionT definition, + UsesT uses, + InputsT inputs, + OutputsT outputs, + std::vector from, + std::vector to, + bool require_all_to_visited = true, + Direction allowed_direction = Direction::Undefined) + : definition_(std::move(definition)), + uses_(std::move(uses)), + inputs_(std::move(inputs)), + outputs_(std::move(outputs)), + from_(std::move(from)), + to_(std::move(to)), + require_all_to_visited_(require_all_to_visited), + allowed_direction_(allowed_direction) {} + + virtual VectorOfUniqueEntries get() { + std::deque + + + return VectorOfUniqueEntries{}; + } + + // Traverse from from_ to to_, recording each taken + // path to generate the shortest path after the travesal + virtual void traverse() { 
+#if 0 + for (const auto& n : from_) { + setVisited(n); + addNewNeighbors(n); + } + + while (!allToNodesVisited()) { + bool something_was_processed = false; + std::deque not_ready; + while (!allToNodesVisited() && !to_visit_.empty()) { + const auto n = to_visit_.front(); + to_visit_.pop_front(); + + if (isVisited(n)) { + continue; + } + + auto ready_direction = isReady(n); + if (!ready_direction.has_value()) { + // To stop an infinite loop, the not-ready node is not moved + // back to the to_visit_ queue but kept in the separate + // queue. This way, if all nodes in to_visit_ are not ready, + // the queue would eventually become empty, which would then + // break the inner while loop. The something_was_processed + // flag is used to remember if there's any progress. + not_ready.emplace_back(n); + continue; + } + + // Visit this node and add its neighbors to to_visit if not + // visited yet + setVisited(n); + setPrevGroups(n, *ready_direction); + addNewNeighbors(n); + something_was_processed = true; + } + + // If nothing was processed, break out of the loop + if (!something_was_processed) { + break; + } + + // Something was processed. Redo the traversal. + to_visit_.insert(to_visit_.end(), not_ready.begin(), not_ready.end()); + } + + if (require_all_to_visited_ && !allToNodesVisited()) { + std::stringstream ss; + for (const auto& to : to_) { + if (!isVisited(to)) { + ss << " " << toString(to); + if (const ExprT* e = std::get_if(&to)) { + ss << " " << toString(*e); + } + } + } + ss << " (from: "; + for (const auto& from : from_) { + ss << " " << toString(from); + if (const ExprT* e = std::get_if(&from)) { + ss << " " << toString(*e); + } + } + ss << ")"; + ss << ", visited: ("; + for (const auto& visited : visited_) { + if (const ValT* v = std::get_if(&visited)) { + ss << " " << toString(visited); + } + } + ss << ")"; + NVF_THROW("BFS traversal could not visit some nodes: ", ss.str()); + } +#endif + } + + // Check if a node is ready to visit. 
If yes, return the direction + // and the prev nodes that should be visited before the given node + // is visited. + virtual std::optional>> isReady( + const NodeType& node) const { + if (const ExprT* e = std::get_if(&node)) { + return isReady(*e); + } else if (const ValT* v = std::get_if(&node)) { + return isReady(*v); + } else { + NVF_THROW(); + } + } + + // Check if an ExprT is ready to visit. Either all of its inputs + // or all of outputs must have their dependencies satisfied. If + // ready because the inputs are already visited, return + // Direction::Forward and all the input nodes. If ready because the + // outputs are ready, return Direction::Backward and all the output nodes. + virtual std::optional>> isReady( + const ExprT& expr) const { + // Either all inputs or all outputs must have been visited + decltype(auto) inputs = inputs_(expr); + if (!inputs.empty() && allowed_direction_ != Direction::Backward && + std::all_of( + inputs.begin(), inputs.end(), [&](const ValT& input) -> bool { + return isDependencySatisfied(input); + })) { + std::vector prev_nodes; + std::copy_if( + inputs.begin(), + inputs.end(), + std::back_inserter(prev_nodes), + [&](const ValT& input) -> bool { return isVisited(input); }); + return std::make_pair(Direction::Forward, prev_nodes); + } + + decltype(auto) outputs = outputs_(expr); + if (!outputs.empty() && allowed_direction_ != Direction::Forward && + std::all_of( + outputs.begin(), outputs.end(), [&](const ValT& output) -> bool { + return isDependencySatisfied(output); + })) { + std::vector prev_nodes; + std::copy_if( + outputs.begin(), + outputs.end(), + std::back_inserter(prev_nodes), + [&](const ValT& output) -> bool { return isVisited(output); }); + return std::make_pair(Direction::Backward, prev_nodes); + } + + return std::nullopt; + } + + // Check if a val is ready to visit. Either its defining or use + // expr must have its dependency satisfied. 
If ready because + // there's a visited defining expr, return Direction::Forward and + // the defining expr. If ready because there's a visited use expr, return + // Direction::Backward and the use expr. + virtual std::optional>> isReady( + const ValT& v) const { + // In the case of Val, requires just one def or use expr. + // Check if any use is visited + decltype(auto) uses = uses_(v); + if (!uses.empty()) { + auto it = std::find_if( + uses.begin(), uses.end(), [&](const ExprT& use_e) -> bool { + return isDependencySatisfied(use_e); + }); + if (it != uses.end()) { + return std::make_pair(Direction::Backward, std::vector{*it}); + } + } + // Check if any def is visited + decltype(auto) def = definition_(v); + if (!def.empty()) { + auto it = + std::find_if(def.begin(), def.end(), [&](const ExprT& def_e) -> bool { + return isDependencySatisfied(def_e); + }); + if (it != def.end()) { + return std::make_pair(Direction::Forward, std::vector{*it}); + } + } + + return std::nullopt; + } + + // If another node depends on a given node, check if that + // dependency is considered satisfied. If the given node is already + // visited, that should mean the dependency is satisfied. 
+ virtual bool isDependencySatisfied(const NodeType& dependency) const { + return isVisited(dependency); + } + + // Check if a given node is already visited + virtual bool isVisited(const NodeType& node) const { + return visited_.find(node) != visited_.end(); + } + + // Mark a node as visited + virtual void setVisited(const NodeType& node) { + visited_.emplace(node); + } + + // Add new neighbors of a given node to the to_visit list + virtual void addNewNeighbors(const NodeType& node) { + auto add_to_visit_list = [&](const NodeType& n) -> void { + if (isVisited(n) || excludeFromTraversal(n)) { + return; + } + to_visit_.emplace_back(n); + }; + + if (const ExprT* e = std::get_if(&node)) { + if (allowed_direction_ == Direction::Backward || + allowed_direction_ == Direction::Undefined) { + for (const auto& v : inputs_(*e)) { + add_to_visit_list(v); + } + } + if (allowed_direction_ == Direction::Forward || + allowed_direction_ == Direction::Undefined) { + for (const auto& v : outputs_(*e)) { + add_to_visit_list(v); + } + } + } else if (const ValT* v = std::get_if(&node)) { + if (allowed_direction_ == Direction::Forward || + allowed_direction_ == Direction::Undefined) { + for (const auto& e : uses_(*v)) { + add_to_visit_list(e); + } + } + if (allowed_direction_ == Direction::Backward || + allowed_direction_ == Direction::Undefined) { + for (const auto& e : definition_(*v)) { + add_to_visit_list(e); + } + } + } else { + NVF_THROW(); + } + } + + // Check if all to_ are visited + virtual bool allToNodesVisited() const { + return std::all_of( + to_.begin(), to_.end(), [&](const NodeType& node) -> bool { + return isVisited(node); + }); + }; + + // Hook to exclude certain graph nodes. 
See IndexingTraversal for a + // concrete example + virtual bool excludeFromTraversal(const NodeType& node) const { + return false; + } + + protected: + const DefinitionT definition_; + const UsesT uses_; + const InputsT inputs_; + const OutputsT outputs_; + const std::vector from_; + const std::vector to_; + std::deque to_visit_; + std::unordered_set visited_; + bool require_all_to_visited_ = true; + Direction allowed_direction_ = Direction::Undefined; +}; + + +} // namespace nvfuser diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index 1c198aee8dc..c6d434a7b69 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -237,7 +237,9 @@ std::unique_ptr ResizeScheduler::computeHeuristics( // Before applying the vectorization split, any reshape transform of // the largest input will be cancelled whenever possible, so the // largest input is used as the reference of vectorization. - auto vec_ref_tv = largest_input != nullptr ? largest_input : ref_tv; + // auto vec_ref_tv = largest_input != nullptr ? largest_input : + // ref_tv; + auto vec_ref_tv = ref_tv; // Only consider the innermost dimension to vectorize for now. // TODO: Consider vectorizing merged IDs, not just the innermost @@ -300,7 +302,8 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) { // The tensors are going to be reordered to align with the largest // input. To make it work, merge operations for reshape should be // cancelled. 
- scheduler_tools::cancelReshapeInLoopDomains(largest_input); + scheduler_tools::cancelReshapeInLoopDomains( + largest_input, /*skip_innermost_id=*/true); } for (auto expr : fusion->exprs()) { diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index 16734cfb294..ccab7fce0e1 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -99,6 +99,8 @@ class LoopDomainSchedulerReplayTransform : OptInConstDispatch { const std::vector& output_ids_; }; +// Replay a given IterDomain transform expression on the loop domain +// of a given tensor using specified loop IDs as its inputs. class ReplayForwardTransformOnLoopDomain : OptInConstDispatch { public: static void replayAs( @@ -541,7 +543,11 @@ void scheduleLoopDomainsBy( // When the direction is forward, the TensorView transform // APIs, e.g., TensorView::split, can be used, which doesn't need - // to use TensorView::setLoopDomain. + // to use TensorView::setLoopDomain. This is important as + // setLoopDomain may result in losing extra IDs added by prior + // scheduleLoopDomain calls, which was indeed the case with the + // Llama 3 RoPE backward (see also + // https://github.com/NVIDIA/Fuser/issues/3571). 
if (replay_dir_tv == Direction::Forward) { ReplayForwardTransformOnLoopDomain::replayAs(tv, input_ids, transform); continue; @@ -582,7 +588,7 @@ void scheduleLoopDomainsBy( return; } -void cancelReshapeInLoopDomains(TensorView* from_tv) { +void cancelReshapeInLoopDomains(TensorView* from_tv, bool skip_innermost_id) { Fusion* fusion = from_tv->fusion(); IdModel id_model(fusion, /*build_graphs=*/false); id_model.buildExactGraph(); @@ -677,6 +683,18 @@ void cancelReshapeInLoopDomains(TensorView* from_tv) { {reshape_out->getLogicalDomain().begin(), reshape_out->getLogicalDomain().end()}); + std::unordered_set reshape_exprs_with_innermost_logical_id_set; + if (skip_innermost_id) { + auto reshape_exprs_with_innermost_logical_id = + DependencyCheck::getAllExprsBetween( + {reshape_out->getRootDomain().begin(), + reshape_out->getRootDomain().end()}, + {reshape_out->getLogicalDomain().back()}); + reshape_exprs_with_innermost_logical_id_set = { + reshape_exprs_with_innermost_logical_id.begin(), + reshape_exprs_with_innermost_logical_id.end()}; + } + auto reshape_out_loop_domain = reshape_out->getLoopDomain(); for (auto reshape_exprs_it = reshape_exprs.rbegin(); @@ -684,6 +702,11 @@ void cancelReshapeInLoopDomains(TensorView* from_tv) { ++reshape_exprs_it) { auto reshape_expr = *reshape_exprs_it; + if (skip_innermost_id && + reshape_exprs_with_innermost_logical_id_set.count(reshape_expr)) { + continue; + } + // If any of the output IDs of reshape_expr is not found in // cancellable_ids, that means the expr cannot be cancelled. if (std::any_of( diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index fa0d4e0d2ae..cedfbb06d19 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -107,7 +107,13 @@ void scheduleLoopDomainsBy( // iter domain is reduced, the split needs to remain. 
If a reshape // only consists of merge transforms, cancellation should be possible, // but that is not currently supported. -void cancelReshapeInLoopDomains(TensorView* from_tv); +// +// When the skip_innermost_id flag is true, any reshape that involves +// innermost logical ID is not canceled even when it's technically +// possible. This is a WAR for the resize scheduler. +void cancelReshapeInLoopDomains( + TensorView* from_tv, + bool skip_innermost_id = false); } // namespace scheduler_tools } // namespace nvfuser diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 5cb98fe3e2b..a29416b8603 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -5822,4 +5822,73 @@ TEST_F(ResizeTest, DoNotFuseResizeAndIndexOps) { } } +TEST_F(ResizeTest, VectorizeInnermostWithReshapeSplit) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + std::vector shape1{128L * 16L}; + std::vector shape2{shape1[0] / 2L, 2L}; + + auto tv0 = makeContigConcreteTensor(shape1); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + auto tv2 = reshape(tv1, shape1, shape2); + auto tv3 = slice( + tv2, + {{IrBuilder::create(0L), IrBuilder::create(2L)}, + {IrBuilder::create(0L), IrBuilder::create(shape2[1])}}); + fusion.addOutput(tv3); + + fusion.printMath(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape1, options); + + // Not sure why, but MarkAliasesPreparePass inserts segment_set + // after reshape + preseg_passes::OptimizationPassGuard + guard(false); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); +} + +TEST_F(ResizeTest, VectorizeInnermostWithReshapeMerge) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + std::vector shape2{16, 128L * 
16L}; + std::vector shape1{16, shape2[1] / 2L, 2L}; + + auto tv0 = makeContigConcreteTensor(shape1); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + // [16, 128 * 16 / 2, 2] -> [16, 128 * 16]. Cancellable reshape. + auto tv2 = reshape(tv1, shape1, shape2); + auto tv3 = slice( + tv2, + {{IrBuilder::create(0L), IrBuilder::create(2L)}, + {IrBuilder::create(0L), IrBuilder::create(shape2[1])}}); + fusion.addOutput(tv3); + + fusion.printMath(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape1, options); + + // Not sure why, but MarkAliasesPreparePass inserts segment_set + // after reshape + preseg_passes::OptimizationPassGuard + guard(false); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); +} + } // namespace nvfuser From e0d4ebae8488f7b03d0107c193312354bfe15751 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 24 Feb 2025 14:50:33 -0800 Subject: [PATCH 2/5] delete --- csrc/graph_traversal_bak.h | 307 ------------------------------------- 1 file changed, 307 deletions(-) delete mode 100644 csrc/graph_traversal_bak.h diff --git a/csrc/graph_traversal_bak.h b/csrc/graph_traversal_bak.h deleted file mode 100644 index 0fd483942f4..00000000000 --- a/csrc/graph_traversal_bak.h +++ /dev/null @@ -1,307 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include - -namespace nvfuser { - -template < - typename ExprT, - typename ValT, - typename DefinitionT, - typename UsesT, - typename InputsT, - typename OutputsT> -class FindAllPaths { - public: - using ExprType = ExprT; - using ValType = ValT; - using NodeType = std::variant; - using ExprPath = std::vector>; - using InputsType = InputsT; - using OutputsType = OutputsT; - - FindAllPaths(DefinitionT definition, - UsesT uses, - InputsT inputs, - OutputsT outputs, - std::vector from, - std::vector to, - bool require_all_to_visited = true, - Direction allowed_direction = Direction::Undefined) - : definition_(std::move(definition)), - uses_(std::move(uses)), - inputs_(std::move(inputs)), - outputs_(std::move(outputs)), - from_(std::move(from)), - to_(std::move(to)), - require_all_to_visited_(require_all_to_visited), - allowed_direction_(allowed_direction) {} - - virtual VectorOfUniqueEntries get() { - std::deque - - - return VectorOfUniqueEntries{}; - } - - // Traverse from from_ to to_, recording each taken - // path to generate the shortest path after the travesal - virtual void traverse() { -#if 0 - for (const auto& n : from_) { - setVisited(n); - addNewNeighbors(n); - } - - while (!allToNodesVisited()) { - bool something_was_processed = false; - std::deque not_ready; - while (!allToNodesVisited() && !to_visit_.empty()) { - const auto n = to_visit_.front(); - to_visit_.pop_front(); - - if (isVisited(n)) { - continue; - } - - auto ready_direction = isReady(n); - if (!ready_direction.has_value()) { - // To stop an infinite loop, the not-ready node is not moved - // back to the to_visit_ queue but kept in the separate - // queue. This way, if all nodes in to_visit_ are not ready, - // the queue would eventually become empty, which would then - // break the inner while loop. The something_was_processed - // flag is used to remember if there's any progress. 
- not_ready.emplace_back(n); - continue; - } - - // Visit this node and add its neighbors to to_visit if not - // visited yet - setVisited(n); - setPrevGroups(n, *ready_direction); - addNewNeighbors(n); - something_was_processed = true; - } - - // If nothing was processed, break out of the loop - if (!something_was_processed) { - break; - } - - // Something was processed. Redo the traversal. - to_visit_.insert(to_visit_.end(), not_ready.begin(), not_ready.end()); - } - - if (require_all_to_visited_ && !allToNodesVisited()) { - std::stringstream ss; - for (const auto& to : to_) { - if (!isVisited(to)) { - ss << " " << toString(to); - if (const ExprT* e = std::get_if(&to)) { - ss << " " << toString(*e); - } - } - } - ss << " (from: "; - for (const auto& from : from_) { - ss << " " << toString(from); - if (const ExprT* e = std::get_if(&from)) { - ss << " " << toString(*e); - } - } - ss << ")"; - ss << ", visited: ("; - for (const auto& visited : visited_) { - if (const ValT* v = std::get_if(&visited)) { - ss << " " << toString(visited); - } - } - ss << ")"; - NVF_THROW("BFS traversal could not visit some nodes: ", ss.str()); - } -#endif - } - - // Check if a node is ready to visit. If yes, return the direction - // and the prev nodes that should be visited before the given node - // is visited. - virtual std::optional>> isReady( - const NodeType& node) const { - if (const ExprT* e = std::get_if(&node)) { - return isReady(*e); - } else if (const ValT* v = std::get_if(&node)) { - return isReady(*v); - } else { - NVF_THROW(); - } - } - - // Check if an ExprT is ready to visit. Either all of its inputs - // or all of outputs must have their dependencies satisfied. If - // ready because the inputs are already visited, return - // Direction::Forward and all the input nodes. If ready because the - // outputs are ready, return Direction::Backward and all the output nodes. 
- virtual std::optional>> isReady( - const ExprT& expr) const { - // Either all inputs or all outputs must have been visited - decltype(auto) inputs = inputs_(expr); - if (!inputs.empty() && allowed_direction_ != Direction::Backward && - std::all_of( - inputs.begin(), inputs.end(), [&](const ValT& input) -> bool { - return isDependencySatisfied(input); - })) { - std::vector prev_nodes; - std::copy_if( - inputs.begin(), - inputs.end(), - std::back_inserter(prev_nodes), - [&](const ValT& input) -> bool { return isVisited(input); }); - return std::make_pair(Direction::Forward, prev_nodes); - } - - decltype(auto) outputs = outputs_(expr); - if (!outputs.empty() && allowed_direction_ != Direction::Forward && - std::all_of( - outputs.begin(), outputs.end(), [&](const ValT& output) -> bool { - return isDependencySatisfied(output); - })) { - std::vector prev_nodes; - std::copy_if( - outputs.begin(), - outputs.end(), - std::back_inserter(prev_nodes), - [&](const ValT& output) -> bool { return isVisited(output); }); - return std::make_pair(Direction::Backward, prev_nodes); - } - - return std::nullopt; - } - - // Check if a val is ready to visit. Either its defining or use - // expr must have its dependency satisfied. If ready because - // there's a visited defining expr, return Direction::Forward and - // the defining expr. If ready because there's a visited use expr, return - // Direction::Backward and the use expr. - virtual std::optional>> isReady( - const ValT& v) const { - // In the case of Val, requires just one def or use expr. 
- // Check if any use is visited - decltype(auto) uses = uses_(v); - if (!uses.empty()) { - auto it = std::find_if( - uses.begin(), uses.end(), [&](const ExprT& use_e) -> bool { - return isDependencySatisfied(use_e); - }); - if (it != uses.end()) { - return std::make_pair(Direction::Backward, std::vector{*it}); - } - } - // Check if any def is visited - decltype(auto) def = definition_(v); - if (!def.empty()) { - auto it = - std::find_if(def.begin(), def.end(), [&](const ExprT& def_e) -> bool { - return isDependencySatisfied(def_e); - }); - if (it != def.end()) { - return std::make_pair(Direction::Forward, std::vector{*it}); - } - } - - return std::nullopt; - } - - // If another node depends on a given node, check if that - // dependency is considered satisfied. If the given node is already - // visited, that should mean the dependency is satisfied. - virtual bool isDependencySatisfied(const NodeType& dependency) const { - return isVisited(dependency); - } - - // Check if a given node is already visited - virtual bool isVisited(const NodeType& node) const { - return visited_.find(node) != visited_.end(); - } - - // Mark a node as visited - virtual void setVisited(const NodeType& node) { - visited_.emplace(node); - } - - // Add new neighbors of a given node to the to_visit list - virtual void addNewNeighbors(const NodeType& node) { - auto add_to_visit_list = [&](const NodeType& n) -> void { - if (isVisited(n) || excludeFromTraversal(n)) { - return; - } - to_visit_.emplace_back(n); - }; - - if (const ExprT* e = std::get_if(&node)) { - if (allowed_direction_ == Direction::Backward || - allowed_direction_ == Direction::Undefined) { - for (const auto& v : inputs_(*e)) { - add_to_visit_list(v); - } - } - if (allowed_direction_ == Direction::Forward || - allowed_direction_ == Direction::Undefined) { - for (const auto& v : outputs_(*e)) { - add_to_visit_list(v); - } - } - } else if (const ValT* v = std::get_if(&node)) { - if (allowed_direction_ == Direction::Forward || - 
allowed_direction_ == Direction::Undefined) { - for (const auto& e : uses_(*v)) { - add_to_visit_list(e); - } - } - if (allowed_direction_ == Direction::Backward || - allowed_direction_ == Direction::Undefined) { - for (const auto& e : definition_(*v)) { - add_to_visit_list(e); - } - } - } else { - NVF_THROW(); - } - } - - // Check if all to_ are visited - virtual bool allToNodesVisited() const { - return std::all_of( - to_.begin(), to_.end(), [&](const NodeType& node) -> bool { - return isVisited(node); - }); - }; - - // Hook to exclude certain graph nodes. See IndexingTraversal for a - // concrete example - virtual bool excludeFromTraversal(const NodeType& node) const { - return false; - } - - protected: - const DefinitionT definition_; - const UsesT uses_; - const InputsT inputs_; - const OutputsT outputs_; - const std::vector from_; - const std::vector to_; - std::deque to_visit_; - std::unordered_set visited_; - bool require_all_to_visited_ = true; - Direction allowed_direction_ = Direction::Undefined; -}; - - -} // namespace nvfuser From 6d86f6fc990e6e5241a83c71b1d36d4d4db94e71 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 24 Feb 2025 15:20:16 -0800 Subject: [PATCH 3/5] test cleanup --- tests/cpp/test_resize.cpp | 51 ++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index a29416b8603..a0ad7bb0d77 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -5822,6 +5822,10 @@ TEST_F(ResizeTest, DoNotFuseResizeAndIndexOps) { } } +// Split-based reshape followed by a slice. The reshape is not +// cancelable. The vectorization factor based on the innermost logical +// ID of the input is not a valid factor as the fusion is scheduled +// based on the post-reshape shape. 
TEST_F(ResizeTest, VectorizeInnermostWithReshapeSplit) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -5841,21 +5845,28 @@ TEST_F(ResizeTest, VectorizeInnermostWithReshapeSplit) { {IrBuilder::create(0L), IrBuilder::create(shape2[1])}}); fusion.addOutput(tv3); - fusion.printMath(); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn(shape1, options); - // Not sure why, but MarkAliasesPreparePass inserts segment_set - // after reshape - preseg_passes::OptimizationPassGuard - guard(false); - - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - auto outputs = executor_cache.runFusionWithInputs({t0}); - testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); + auto outputs = scheduleAndRun(&fusion, SchedulerType::Resize, {t0}); + testValidate(&fusion, outputs.outputs, {t0}, __LINE__, __FILE__); + + // Should be vectorized by a factor of 2 because the resize scheduler + // only uses the innermost logical ID, and the extent of the output + // tensor is just 2. Before PR #3955, the resize scheduler + // attempted to vectorize by 4. Note that the slice op itself does + // not matter for the vectorization as the sliced ID is not involved + // in the vectorization. + EXPECT_EQ( + tv3->getLoopDomain().back()->getParallelType(), ParallelType::Vectorize); + EXPECT_EQ(tv3->getLoopDomain().back()->extent()->evaluate(), 2); } +// Merge-based reshape followed by a slice. The reshape is +// cancelable. If the output is used as the reference but the reshape +// is canceled, the valid vectorization factor should be 2. The WAR of +// PR #3955 gives up canceling any reshape that involves innermost +// logical IDs to avoid this inconsistency. 
TEST_F(ResizeTest, VectorizeInnermostWithReshapeMerge) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -5868,7 +5879,7 @@ TEST_F(ResizeTest, VectorizeInnermostWithReshapeMerge) { fusion.addInput(tv0); auto tv1 = sin(tv0); - // [16, 128 * 16 / 2, 2] -> [16, 128 * 16]. Cancellable reshape. + // [16, 128 * 16 / 2, 2] -> [16, 128 * 16]. Cancelable reshape. auto tv2 = reshape(tv1, shape1, shape2); auto tv3 = slice( tv2, @@ -5876,19 +5887,19 @@ TEST_F(ResizeTest, VectorizeInnermostWithReshapeMerge) { {IrBuilder::create(0L), IrBuilder::create(shape2[1])}}); fusion.addOutput(tv3); - fusion.printMath(); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn(shape1, options); - // Not sure why, but MarkAliasesPreparePass inserts segment_set - // after reshape - preseg_passes::OptimizationPassGuard - guard(false); + auto outputs = scheduleAndRun(&fusion, SchedulerType::Resize, {t0}); + testValidate(&fusion, outputs.outputs, {t0}, __LINE__, __FILE__); - FusionExecutorCache executor_cache(std::move(fusion_ptr)); - auto outputs = executor_cache.runFusionWithInputs({t0}); - testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); + // Should be vectorized by a factor of 4. If the reshape were canceled, + // it should have been 2, but in this case since it involves the + // innermost logical ID of tv2, it is not canceled, thus + // vectorization by 4 should be chosen. 
+ EXPECT_EQ( + tv3->getLoopDomain().back()->getParallelType(), ParallelType::Vectorize); + EXPECT_EQ(tv3->getLoopDomain().back()->extent()->evaluate(), 4); } } // namespace nvfuser From 7865aa87460232e4b91910a2185aef2ac9e06871 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 24 Feb 2025 15:38:20 -0800 Subject: [PATCH 4/5] cleanup --- csrc/scheduler/resize.cpp | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index c6d434a7b69..46b2813dd39 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -234,20 +234,13 @@ std::unique_ptr ResizeScheduler::computeHeuristics( }); TensorView* ref_tv = ref_tv_entry.get()[0]; - // Before applying the vectorization split, any reshape transform of - // the largest input will be cancelled whenever possible, so the - // largest input is used as the reference of vectorization. - // auto vec_ref_tv = largest_input != nullptr ? largest_input : - // ref_tv; - auto vec_ref_tv = ref_tv; - // Only consider the innermost dimension to vectorize for now. 
// TODO: Consider vectorizing merged IDs, not just the innermost params->vectorization_factor = vectorize_helper::getVectorizationFactor( runtime_info, - vec_ref_tv, + ref_tv, data_cache, - (int64_t)vec_ref_tv->getLogicalDomain().size() - 1, + (int64_t)ref_tv->getLogicalDomain().size() - 1, {}); return params; From 7df8e0713911e7295b7d5db4cb0ab07e956fd949 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 24 Feb 2025 23:17:42 -0800 Subject: [PATCH 5/5] WAR for the other vec issue --- tests/cpp/test_resize.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index a0ad7bb0d77..fea1975fef5 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -4755,7 +4755,7 @@ TEST_P(ResizeSchedulerTest, SliceRotateCat) { Fusion& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); - std::vector shape({-1, 100}); + std::vector shape({-1, 128}); EnableOptionsGuard enable_options_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); @@ -4781,7 +4781,7 @@ TEST_P(ResizeSchedulerTest, SliceRotateCat) { auto tv5 = cat({tv4, tv2}, 1); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 100}, options); + auto t0 = at::randn({16, 128}, options); fusion.addOutput(tv5);