diff --git a/csrc/graph_traversal.h b/csrc/graph_traversal.h new file mode 100644 index 00000000000..7bb34c395a7 --- /dev/null +++ b/csrc/graph_traversal.h @@ -0,0 +1,596 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +namespace nvfuser { + +// Find all exprs reachable from from_nodes when traversing to to_nodes. Edges +// are visited only once, but nodes may be visited multiple times. Edges are +// always between ExprT and ValT and are directed, e.g., an edge from an +// ExprGroup to a ValGroup is differentiated from an edge from the ValGroup to +// the ExprGroup, and both of them may be visited. +// +// When there's a cycle, exprs in the cycle are also included. For +// example, given a graph like (each symbol represents an expr): +// +// A -> B -> C -> D -> E +// ^ | +// +--- F ---+ +// +// Exprs of {A_fwd, F_bwd, B_fwd, C_fwd, D_fwd, E_fwd} would be +// returned. Note that there's no guarantee of ordering, although it +// is at least partially sorted in a topological order. +// +// The overall traversal algorithm is to start from from_nodes and +// traverse edges in both directions or only in a specified +// direction. Unlike BFS, it keeps traversing even if all +// the to_nodes are reached and stops when no further progress is +// possible. At this point, we know all the reachable edges from +// from_nodes but we are only interested in those that reach to_nodes. To +// find those edges, another traversal, this time from to_nodes, is +// done to mark all visited edges that are reachable from +// to_nodes. That gives us all the edges between from_nodes and +// to_nodes. Finally, ExprPath is returned based on the exprs of the +// edges. +// +// NOTE 1: The algorithm and the implementation are based on the BFS +// class. There are likely more efficient algorithms.
+// +// NOTE 2: The returned expr path is not guaranteed to be +// topologically sorted, which is not possible for cyclic graphs. +template < + typename ExprT, + typename ValT, + typename DefinitionT, + typename UsesT, + typename InputsT, + typename OutputsT> +class FindAllExprs { + public: + using ExprType = ExprT; + using ValType = ValT; + using NodeType = std::variant; + using ExprPath = std::vector>; + + // Edge represents an edge in the graph. By definition, it must be + // between an expr and a val. + struct Edge { + NodeType from; + NodeType to; + Edge(const ValT& from, const ExprT& to) : from(from), to(to) {} + Edge(const ExprT& from, const ValT& to) : from(from), to(to) {} + bool operator==(const Edge& other) const { + return from == other.from && to == other.to; + } + std::string toString() const { + std::stringstream ss; + ss << "{" << nvfuser::toString(from) << " -> " << nvfuser::toString(to) + << "}"; + return ss.str(); + } + }; + + struct EdgeHash { + std::size_t operator()(const Edge& edge) const { + return std::hash()(edge.from) ^ std::hash()(edge.to); + } + }; + + using EdgeSet = std::unordered_set; + + virtual ~FindAllExprs() = default; + + public: + FindAllExprs( + DefinitionT definition, + UsesT uses, + InputsT inputs, + OutputsT outputs, + std::vector from, + std::vector to, + bool require_all_to_visited = true, + Direction allowed_direction = Direction::Undefined) + : definition_(std::move(definition)), + uses_(std::move(uses)), + inputs_(std::move(inputs)), + outputs_(std::move(outputs)), + from_nodes_(std::move(from)), + to_nodes_(std::move(to)), + require_all_to_visited_(require_all_to_visited), + allowed_direction_(allowed_direction) {} + + virtual void traverseAllEdges() { + std::deque edges_to_visit; + + for (const auto& from_node : from_nodes_) { + if (const ValT* from_val = std::get_if(&from_node)) { + for (const auto& use_expr : uses_(*from_val)) { + Edge edge(*from_val, use_expr); + setVisited(edge); + for (const auto& next_edge : + 
getConsumerEdges(edge, allowed_direction_)) { + edges_to_visit.push_back(next_edge); + } + } + for (const auto& def_expr : definition_(*from_val)) { + Edge edge(*from_val, def_expr); + setVisited(edge); + for (const auto& next_edge : + getConsumerEdges(edge, allowed_direction_)) { + edges_to_visit.push_back(next_edge); + } + } + } else { + NVF_THROW( + "Traversal from nodes are assumed to be all Vals but found: ", + toString(from_node)); + } + } + + bool something_was_processed = true; + while (something_was_processed) { + std::deque not_ready; + something_was_processed = false; + + while (!edges_to_visit.empty()) { + const auto edge_to_visit = edges_to_visit.front(); + edges_to_visit.pop_front(); + + // Don't visit edges multiple times even when traversing all paths + if (isVisited(edge_to_visit)) { + continue; + } + + auto prev_edges = isReady(edge_to_visit); + if (prev_edges.empty()) { + // To stop an infinite loop, the not-ready edge is not moved + // back to the edges_to_visit queue but kept in the separate + // not_ready queue. This way, if no edge in edges_to_visit is ready, + // the queue would eventually become empty, which would then + // break the inner while loop. The something_was_processed + // flag is used to remember if there's any progress. + not_ready.emplace_back(edge_to_visit); + continue; + } + + setVisited(edge_to_visit); + for (const auto& next_edge : + getConsumerEdges(edge_to_visit, allowed_direction_)) { + edges_to_visit.push_back(next_edge); + } + something_was_processed = true; + } + + // Something was processed. Redo the traversal.
+ edges_to_visit.insert( + edges_to_visit.end(), not_ready.begin(), not_ready.end()); + } + + if (require_all_to_visited_ && !allToNodesVisited()) { + auto visited_nodes = getVisitedNodes(); + std::stringstream ss; + for (const auto& to : to_nodes_) { + if (!visited_nodes.count(to)) { + ss << " " << toString(to); + } + } + ss << " (from: "; + for (const auto& from : from_nodes_) { + ss << " " << toString(from); + } + ss << ")"; + ss << ", visited: ("; + for (const auto& visited : visited_nodes) { + if (const ValT* v = std::get_if(&visited)) { + ss << " " << toString(visited); + } + } + ss << ")"; + NVF_THROW("BFS traversal could not visit some nodes: ", ss.str()); + } + } + + // Check if an edge is ready to visit. If yes, return the prev + // edges that should be visited before the given edge is visited. + // If not, an empty vector is returned. + virtual std::vector isReady(const Edge& edge) const { + Direction dir = getDirection(edge); + + // If a direction is specified, only edges of that direction are + // allowed. + if ((dir == Direction::Forward && + allowed_direction_ == Direction::Backward) || + (dir == Direction::Backward && + allowed_direction_ == Direction::Forward)) { + return {}; + } + + if (const ExprT* e = std::get_if(&(edge.from))) { + return isReady(*e, std::get(edge.to), dir); + } else if (const ValT* v = std::get_if(&(edge.from))) { + return isReady(*v, std::get(edge.to), dir); + } else { + NVF_THROW(); + } + } + + // Check if an edge from an expr to a val is ready to visit. If this + // is a forward edge, i.e., the val is an output of the expr, the + // edge is ready to visit as long as all the inputs of the expr are + // visited. If it's a backward edge, i.e., the val is an input of + // the expr, it's ready if all of the outputs are visited. If ready, + // the edges that this edge depends on are returned. For example, in + // the case of a forward edge, all of the edges to from_expr are + // returned.
+ virtual std::vector isReady( + const ExprT& from_expr, + const ValT& to_val, + Direction dir) const { + if (dir == Direction::Forward) { + decltype(auto) inputs = inputs_(from_expr); + if (std::all_of( + inputs.begin(), inputs.end(), [&](const ValT& input) -> bool { + return isVisited(Edge(input, from_expr)); + })) { + std::vector prev_edges; + for (const ValT& input : inputs) { + prev_edges.push_back(Edge(input, from_expr)); + } + return prev_edges; + } + } else if (dir == Direction::Backward) { + decltype(auto) outputs = outputs_(from_expr); + if (std::all_of( + outputs.begin(), outputs.end(), [&](const ValT& output) -> bool { + return isVisited(Edge(output, from_expr)); + })) { + std::vector prev_edges; + for (const ValT& output : outputs) { + prev_edges.push_back(Edge(output, from_expr)); + } + return prev_edges; + } + } + + return {}; + } + + // Check if an edge from a val to an expr is ready to visit. In the + // case of a val, it is ready to visit as long as there's at least + // one def or use expr that has been already visited. However, since + // this is an edge to an expr, the edge from the same expr to this + // val does not make this edge ready to visit. For example, even if + // a merge producing i0 is visited, it should not automatically mean + // the edge from i0 to the merge expr is ready to visit. Otherwise, + // the traversal would just move back and forth.
+ virtual std::vector isReady( + const ValT& from_val, + const ExprT& to_expr, + Direction dir) const { + std::vector prev_edges; + + // Check if any def is visited + decltype(auto) def = definition_(from_val); + if (!def.empty()) { + for (const ExprT& def_e : def) { + if (def_e != to_expr && isVisited(Edge(def_e, from_val))) { + prev_edges.emplace_back(Edge(def_e, from_val)); + } + } + } + + decltype(auto) uses = uses_(from_val); + for (const ExprT& use_e : uses) { + if (use_e != to_expr && isVisited(Edge(use_e, from_val))) { + prev_edges.emplace_back(Edge(use_e, from_val)); + } + } + + return prev_edges; + } + + // Check if a given edge is already visited + virtual bool isVisited(const Edge& edge) const { + return visited_edges_.find(edge) != visited_edges_.end(); + } + + virtual void setVisited(const Edge& edge) { + if (visited_edges_.emplace(edge).second) { + partially_ordered_visited_edges_.push_back(edge); + } + } + + // Get edges that are consumers or producers of a given edge. A + // consumer edge of edge A->B is an edge that has node B as its from + // node. A producer edge is an edge that has node A as its to node.
+ virtual std::vector getConsumerOrProducerEdges( + const Edge& edge, + bool is_consumer, + Direction allowed_direction = Direction::Undefined) const { + std::vector neighbor_edges; + + auto add_to_neighbor_list = [&](const auto& from, const auto& to) -> void { + Edge neighbor_edge(from, to); + + if (edge == neighbor_edge || + // Don't traverse back + (edge.from == neighbor_edge.to && edge.to == neighbor_edge.from)) { + return; + } + + if (excludeFromTraversal(neighbor_edge)) { + return; + } + + auto neighbor_edge_dir = getDirection(neighbor_edge); + if ((allowed_direction == Direction::Forward && + neighbor_edge_dir == Direction::Backward) || + (allowed_direction == Direction::Backward && + neighbor_edge_dir == Direction::Forward)) { + return; + } + + neighbor_edges.push_back(neighbor_edge); + }; + + Direction edge_dir = getDirection(edge); + NVF_ERROR( + edge_dir == Direction::Forward || edge_dir == Direction::Backward); + + const auto& node = is_consumer ? edge.to : edge.from; + + // `node` is the node whose neighbor edges are gathered: the to node + // of the edge when grabbing consumer edges and the from node when + // grabbing producer edges. For a forward edge with an Expr node, if + // the node is the to of the edge, the edge is from an input Val to + // its use Expr, so traverse from the use Expr to its outputs. If the + // node is the from of the edge, the edge is from a defining expr to + // one of its outputs; in that case grab edges of the inputs of the expr. + + if (const ExprT* e = std::get_if(&node)) { + // The other endpoint of the edge must be a Val. + + // In the case of Expr, only consider edges of the same + // direction + if (edge_dir == Direction::Forward) { + if (is_consumer) { + // Grab consumer edges of the forward edge to the expr. The + // edge represents a use expr of the from val. Consumers are + // forward edges from the expr to its outputs. + for (const auto& v : outputs_(*e)) { + add_to_neighbor_list(*e, v); + } + } else { + // Grab producer edges of the forward edge from the expr.
The + // edge represents a defining expr of the to val. Producers + // are forward edges to the defining expr from its inputs. + for (const auto& v : inputs_(*e)) { + add_to_neighbor_list(v, *e); + } + } + } else if (edge_dir == Direction::Backward) { + if (is_consumer) { + // Grab consumer edges of the backward edge to the expr. The + // edge represents a defining expr of the from val. Consumers + // are backward edges from the defining expr to its inputs. + for (const auto& v : inputs_(*e)) { + add_to_neighbor_list(*e, v); + } + } else { + // Grab producer edges of the backward edge from the expr. The + // edge represents a use expr of the to val. Producers + // are backward edges to the use expr from its outputs. + for (const auto& v : outputs_(*e)) { + add_to_neighbor_list(v, *e); + } + } + } + } else if (const ValT* v = std::get_if(&node)) { + // The other endpoint of the edge must be an Expr. + + // In the case of Val, no matter what direction this node is, it + // should be valid to traverse both directions. Just don't + // traverse back to the same node.
+ + for (const auto& e : uses_(*v)) { + if (is_consumer) { + // Uses of v are forward consumer edges of the edge to val v + add_to_neighbor_list(*v, e); + } else { + // Uses of v are backward producer edges of the edge from val v + add_to_neighbor_list(e, *v); + } + } + + for (const auto& e : definition_(*v)) { + if (is_consumer) { + // Defs of v are backward consumer edges of the edge to val v + add_to_neighbor_list(*v, e); + } else { + // Defs of v are forward producer edges of the edge from val v + add_to_neighbor_list(e, *v); + } + } + } + + return neighbor_edges; + } + + // Get edges that should be traversed from the to node of a given edge + virtual std::vector getConsumerEdges( + const Edge& edge, + Direction allowed_direction = Direction::Undefined) const { + return getConsumerOrProducerEdges( + edge, /*is_consumer=*/true, allowed_direction); + } + + // Get edges that should be traversed before the from node of a given edge + virtual std::vector getProducerEdges( + const Edge& edge, + Direction allowed_direction = Direction::Undefined) const { + return getConsumerOrProducerEdges( + edge, /*is_consumer=*/false, allowed_direction); + } + + // Check if all to_ are visited + virtual bool allToNodesVisited() const { + auto visited_nodes = getVisitedNodes(); + return std::all_of( + to_nodes_.begin(), to_nodes_.end(), [&](const NodeType& node) -> bool { + return visited_nodes.count(node); + }); + }; + + // Hook to exclude certain graph edges. + virtual bool excludeFromTraversal(const Edge& edge) const { + return false; + } + + // If an edge is from a val to its use expr, it's a forward + // edge. Similarly, it's also a forward edge if it's an expr to one + // of its outputs. Otherwise, it's a backward edge. 
+ Direction getDirection(const Edge& edge) const { + if (const ExprT* from_expr = std::get_if(&edge.from)) { + const ValT& to_val = std::get(edge.to); + decltype(auto) inputs = inputs_(*from_expr); + if (std::find(inputs.begin(), inputs.end(), to_val) != inputs.end()) { + return Direction::Backward; + } + decltype(auto) outputs = outputs_(*from_expr); + if (std::find(outputs.begin(), outputs.end(), to_val) != outputs.end()) { + return Direction::Forward; + } + NVF_THROW(); + } else if (const ValT* from_val = std::get_if(&edge.from)) { + const ExprT& to_expr = std::get(edge.to); + decltype(auto) inputs = inputs_(to_expr); + if (std::find(inputs.begin(), inputs.end(), *from_val) != inputs.end()) { + return Direction::Forward; + } + decltype(auto) outputs = outputs_(to_expr); + if (std::find(outputs.begin(), outputs.end(), *from_val) != + outputs.end()) { + return Direction::Backward; + } + NVF_THROW(); + } else { + NVF_THROW(); + } + } + + virtual std::unordered_set getVisitedNodes() const { + std::unordered_set visited_nodes; + visited_nodes.insert(from_nodes_.begin(), from_nodes_.end()); + for (const auto& visited_edge : visited_edges_) { + visited_nodes.emplace(visited_edge.from); + visited_nodes.emplace(visited_edge.to); + } + return visited_nodes; + } + + // Grab all visited edges that are reachable from from_nodes and + // to_nodes. traverseAllEdges must have been completed. + virtual EdgeSet getUsedEdges() const { + NVF_ERROR( + !require_all_to_visited_ || allToNodesVisited(), + "Traveral is either not done or failed"); + + // Traverse back from to_ nodes to from_ nodes by traversing + // through visited edges + std::deque to_visit; + + // Gather all visited edges to the to_ nodes.
These edges are used + // as initial edges for the traversal below + for (const NodeType& to_node : to_nodes_) { + if (const ValT* to_val = std::get_if(&to_node)) { + for (const ExprT& use_expr : uses_(*to_val)) { + Edge e{use_expr, *to_val}; + if (isVisited(e)) { + to_visit.emplace_back(e); + } + } + for (const ExprT& def_expr : definition_(*to_val)) { + Edge e{def_expr, *to_val}; + if (isVisited(e)) { + to_visit.emplace_back(e); + } + } + } else { + NVF_THROW( + "Traversal to nodes are assumed to be all Vals but found: ", + toString(to_node)); + } + } + + EdgeSet used_edges; + + while (!to_visit.empty()) { + const auto edge_to_visit = to_visit.front(); + to_visit.pop_front(); + + if (used_edges.count(edge_to_visit)) { + continue; + } + + auto producer_edges = getProducerEdges(edge_to_visit); + for (const Edge& producer_edge : producer_edges) { + if (isVisited(producer_edge)) { + to_visit.emplace_back(producer_edge); + } + } + + used_edges.insert(edge_to_visit); + } + + return used_edges; + } + + // Return ExprPath consisting of all exprs appearing between + // from_nodes and to_nodes. The exprs are partially topologically + // sorted, but not completely. The ordering should be deterministic, + // but do not assume any particular ordering. + virtual std::pair getPartiallyOrderedExprs() const { + const auto used_edges = getUsedEdges(); + + VectorOfUniqueEntries> expr_path; + + for (const Edge& ordered_visited_edge : partially_ordered_visited_edges_) { + if (!used_edges.count(ordered_visited_edge)) { + continue; + } + + Direction edge_dir = getDirection(ordered_visited_edge); + + // Append the expr of this edge + const ExprT& expr = + std::get_if(&(ordered_visited_edge.from)) != nullptr + ?
std::get(ordered_visited_edge.from) + : std::get(ordered_visited_edge.to); + expr_path.pushBack(std::make_pair(expr, edge_dir)); + } + + return std::make_pair(expr_path.vector(), allToNodesVisited()); + } + + protected: + const DefinitionT definition_; + const UsesT uses_; + const InputsT inputs_; + const OutputsT outputs_; + const std::vector from_nodes_; + const std::vector to_nodes_; + bool require_all_to_visited_ = true; + Direction allowed_direction_ = Direction::Undefined; + + EdgeSet visited_edges_; + std::vector partially_ordered_visited_edges_; +}; + +} // namespace nvfuser diff --git a/csrc/scheduler/compile_time_info.h b/csrc/scheduler/compile_time_info.h index 3436bcd70eb..fa0f5199067 100644 --- a/csrc/scheduler/compile_time_info.h +++ b/csrc/scheduler/compile_time_info.h @@ -38,6 +38,7 @@ enum class CompileTimeEntryType { VECTORIZABLE_INPUTS_AND_OUTPUTS, INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS, TV_TO_CONTIG_INNER_SIZE_MAPS, + RESIZE_VECTORIZATION_FACTORS, UNROLLABLE_INPUTS_AND_OUTPUTS, REDUCTION_TVS, PERSISTENT_BUFFER_INFO, @@ -106,6 +107,15 @@ class TvToContigInnerSizeMaps { CompileTimeEntryType::TV_TO_CONTIG_INNER_SIZE_MAPS; }; +//! Stores the scalar vals that a vectorization factor must be able to +//! divide evenly +class ResizeVectorizationFactors { + public: + using DataType = std::unordered_set; + static const CompileTimeEntryType EntryType = + CompileTimeEntryType::RESIZE_VECTORIZATION_FACTORS; +}; + //! Entry type definition class for `INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS`, //! stores the fusion's inputs and outputs grouped by inner most dimension. 
class InputsOutputsInnerDimGroups { diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index c1f1b33720d..33316c9480e 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -223,6 +223,8 @@ template class HeuristicDataCacheEntry< HeuristicCompileTime::VectorizableInputsAndOutputs>; template class HeuristicDataCacheEntry< HeuristicCompileTime::TvToContigInnerSizeMaps>; +template class HeuristicDataCacheEntry< + HeuristicCompileTime::ResizeVectorizationFactors>; template class HeuristicDataCacheEntry< HeuristicCompileTime::InputsOutputsInnerDimGroups>; template class HeuristicDataCacheEntry< diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp index 46b2813dd39..171ef740e40 100644 --- a/csrc/scheduler/resize.cpp +++ b/csrc/scheduler/resize.cpp @@ -100,7 +100,7 @@ bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) { IdModel id_model(fusion, /*build_graphs=*/false); const auto& broadcast_graph = id_model.buildBroadcastGraph(); - auto resize_tensor_ops = ir_utils::getOpsOfType(fusion); + auto resize_tensor_ops = scheduler_tools::getResizeBasedOps(fusion); // Slicing of or to a broadcast ID is not allowed yet. 
for (auto resize_tensor_op : resize_tensor_ops) { diff --git a/csrc/scheduler/tools/resize_utils.cpp b/csrc/scheduler/tools/resize_utils.cpp index f281b40fe79..4abd4b3a856 100644 --- a/csrc/scheduler/tools/resize_utils.cpp +++ b/csrc/scheduler/tools/resize_utils.cpp @@ -26,6 +26,10 @@ bool hasResizeBasedOps(Fusion* fusion) { return ir_utils::hasOpsOfType(fusion); } +std::vector getResizeBasedOps(Fusion* fusion) { + return ir_utils::getOpsOfType(fusion); +} + void propagateResizeToInputs(Expr* resize_tensor_op) { NVF_ERROR( resize_tensor_op->isA() || resize_tensor_op->isA(), diff --git a/csrc/scheduler/tools/resize_utils.h b/csrc/scheduler/tools/resize_utils.h index 99e03153a37..ec03e983fc4 100644 --- a/csrc/scheduler/tools/resize_utils.h +++ b/csrc/scheduler/tools/resize_utils.h @@ -20,6 +20,8 @@ bool isResizeBasedOp(Expr* expr); bool hasResizeBasedOps(Fusion* fusion); +std::vector getResizeBasedOps(Fusion* fusion); + // For a given resize-based tensor op such as SliceOp and PadOp, make the loop // domain of each dependent producer tensor exact-mapped by propagating // the iter-domain ops of the output tensor of the given op. Note that diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index ceebca598e7..d4e52b21076 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -843,6 +844,96 @@ std::vector> getTvToContigInnerSizeMapsOf( return mappers; } +// This is a WAR for vectorizing through resized iter domains. The +// spanning tree based analysis is not guaranteed to take all resize +// ops into consideration (issue +// https://github.com/NVIDIA/Fuser/issues/3640). To work around the +// limitation, grab all factors that must be divisible by a +// vectorization factor.
+std::unordered_set getResizeVectorizationFactors( + TensorView* reference_tv, + int64_t break_point) { + Fusion* fusion = reference_tv->fusion(); + std::unordered_set factors; + const auto resize_based_ops = scheduler_tools::getResizeBasedOps(fusion); + + if (resize_based_ops.empty()) { + return factors; + } + + IdModel id_model(reference_tv->fusion()); + const auto& graph = id_model.buildExactGraph(); + + const auto ref_groups = graph.toGroups(reference_tv->getLogicalDomain()); + + // For each of the resize-based tensor ops, find all resize ops + // that exist between the vectorized reference IDs and the output + // tensor. + for (auto resize_based_op : resize_based_ops) { + auto resize_out = resize_based_op->output(0)->as(); + NVF_ERROR( + resize_out->hasRoot(), "Unexpected op: ", resize_based_op->toString()); + // getAllExprGroupsBetween finds exprs between IDs. To make sure + // the resize op of this resize_based_op tensor op is found, + // use both the root and logical domains as the traversal targets. + ValGroups resize_inp_out; + resize_inp_out.pushBack(graph.toGroups(resize_out->getRootDomain())); + resize_inp_out.pushBack(graph.toGroups(resize_out->getLogicalDomain())); + + auto expr_path = getAllExprGroupsBetween( + graph, + ref_groups, + resize_inp_out, + /*require_all_to_visited=*/false) + .first; + + ValGroups vectorized_groups; + for (auto it = reference_tv->getLogicalDomain().begin() + break_point; + it != reference_tv->getLogicalDomain().end(); + ++it) { + vectorized_groups.pushBack(graph.toGroup(*it)); + } + + // Find all resize exprs that appear in expr_path and depend on + // vectorized_groups. Since expr_path is not guaranteed to be + // topologically sorted, need to loop through the path until + // converged.
+ + bool something_has_changed = true; + while (something_has_changed) { + something_has_changed = false; + for (const auto& [expr_g, dir] : expr_path) { + const auto inputs = getInputsOfExprGroup(graph, expr_g, dir); + if (std::none_of( + inputs.begin(), inputs.end(), [&](const ValGroup& inp) { + return vectorized_groups.has(inp); + })) { + continue; + } + + if (vectorized_groups.pushBack( + getOutputsOfExprGroup(graph, expr_g, dir))) { + something_has_changed = true; + } + + auto resize = dynamic_cast(expr_g->front()); + if (resize == nullptr) { + continue; + } + + // These three vals need to be divisible + factors.emplace(resize->leftExpand()); + factors.emplace(resize->rightExpand()); + factors.emplace( + dir == Direction::Forward ? resize->out()->extent() + : resize->in()->extent()); + } + } + } + + return factors; +} + } // namespace int64_t getVectorizationFactor( @@ -881,6 +972,15 @@ int64_t getVectorizationFactor( return 1; } + auto resize_factors_entry = + HeuristicDataCacheEntry( + data_cache, [&reference_tv, &break_point]() { + return std::make_unique>( + getResizeVectorizationFactors(reference_tv, break_point)); + }); + + const auto& resize_factors = resize_factors_entry.get(); + int64_t max_vec_size = SchedulerRuntimeInfo::max_alignment_size_in_byte; const auto& tv_to_inner_size_map = vectorize_maps_entry.get().at(break_point); @@ -920,6 +1020,19 @@ int64_t getVectorizationFactor( max_vec_size); } + // This is a WAR for vectorization through resize as the spanning + // tree based traversal is not guaranteed to reflect all resize ops + // that may affect vectorization. This is a safe but conservative + // analysis since it should only be necessary for innermost IDs. 
+ for (const auto resize_factor : resize_factors) { + auto inferred_val = + runtime_info.expressionEvaluator().evaluate(resize_factor); + if (!inferred_val.hasValue()) { + return 1; + } + max_vec_size = std::gcd(max_vec_size, inferred_val.as()); + } + return max_vec_size; } diff --git a/csrc/val_graph_visitor.cpp b/csrc/val_graph_visitor.cpp index 8110f3e12bd..92376ca0b33 100644 --- a/csrc/val_graph_visitor.cpp +++ b/csrc/val_graph_visitor.cpp @@ -5,11 +5,9 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include - +#include #include - -#include +#include namespace nvfuser { @@ -245,4 +243,30 @@ bool isCyclic(const ValGraph& graph) { return ValGraphCycleDetector(graph).cycle_detected_; } +std::pair getAllExprGroupsBetween( + const ValGraph& graph, + const ValGroups& from, + const ValGroups& to, + bool require_all_to_visited, + Direction allowed_direction) { + FindAllExprs< + ExprGroup, + ValGroup, + ValGraphDefinitions, + ValGraphUses, + ValGraphInputs, + ValGraphOutputs> + finder( + ValGraphDefinitions{graph}, + ValGraphUses{graph}, + ValGraphInputs{graph}, + ValGraphOutputs{graph}, + {from.vector().begin(), from.vector().end()}, + {to.vector().begin(), to.vector().end()}, + require_all_to_visited, + allowed_direction); + finder.traverseAllEdges(); + return finder.getPartiallyOrderedExprs(); +} + } // namespace nvfuser diff --git a/csrc/val_graph_visitor.h b/csrc/val_graph_visitor.h index 40e1a9b14bf..c7830a3de86 100644 --- a/csrc/val_graph_visitor.h +++ b/csrc/val_graph_visitor.h @@ -197,6 +197,8 @@ struct GetValType { using type = ValGroup; }; +using ExprGroupPath = std::vector>; + class ValGraphBFS : public BFS< ExprGroup, ValGroup, @@ -292,4 +294,13 @@ inline std::vector getOutputsOfExprGroup( expr, dir, ValGraphInputs(graph), ValGraphOutputs(graph)); } +// Grab all ExprGroups between two sets of ValGroups. ExprGroups are +// not guaranteed to be topologically sorted.
+std::pair getAllExprGroupsBetween( + const ValGraph& graph, + const ValGroups& from, + const ValGroups& to, + bool require_all_to_visited = true, + Direction allowed_direction = Direction::Undefined); + } // namespace nvfuser diff --git a/tests/cpp/test_bfs.cpp b/tests/cpp/test_bfs.cpp index be9063eca9c..296d56db3f4 100644 --- a/tests/cpp/test_bfs.cpp +++ b/tests/cpp/test_bfs.cpp @@ -16,9 +16,6 @@ #include #include -#include -#include - namespace nvfuser { using BFSTest = NVFuserTest; @@ -577,4 +574,361 @@ TEST_F(BFSTest, IRBFSPermissiveTraversal2) { .second); } +using FindAllExprsTest = NVFuserTest; + +#define VALIDATE_EXPR_PATH(actual, ref) \ + do { \ + ASSERT_EQ(actual.size(), ref.size()); \ + for (const auto i : c10::irange(actual.size())) { \ + EXPECT_EQ(actual.at(i).first, ref.at(i).first) \ + << "Mismathed expr at " << i \ + << ". Expected: " << nvfuser::toString(ref.at(i).first) \ + << ". Actual: " << nvfuser::toString(actual.at(i).first); \ + EXPECT_EQ(actual.at(i).second, ref.at(i).second) \ + << "Mismathed direction at " << i \ + << ". Expected: " << ref.at(i).second \ + << ". 
Actual: " << actual.at(i).second; \ + } \ + } while (0); + +// Traversing backward and then forward +TEST_F(FindAllExprsTest, Test1) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + // tv0: [i0, i1] + // tv1: [i0, i1] + + // Use different merge orderings + tv0->merge(0, 1)->split(0, 4); + tv1->merge(1, 0)->split(0, 4); + + // tv0: [i0*i1/4, 4] + // tv1: [i1*i0/4, 4] + + IdModel id_model(&fusion); + const ValGraph& graph = id_model.buildExactGraph(); + + ValGroups tv0_loop_groups = graph.toGroups(tv0->getLoopDomain()); + ValGroups tv1_loop_groups = graph.toGroups(tv1->getLoopDomain()); + + auto result = + getAllExprGroupsBetween(graph, tv0_loop_groups, tv1_loop_groups).first; + + ExprGroupPath reference_path{ + {graph.toGroup(tv0->getLoopDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup( + tv0->getLoopDomain().at(0)->definition()->input(0)->definition()), + Direction::Backward}, + {graph.toGroup( + tv1->getLoopDomain().at(0)->definition()->input(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv1->getLoopDomain().at(0)->definition()), + Direction::Forward}, + }; + + VALIDATE_EXPR_PATH(result, reference_path); +} + +// Traversing a cyclic graph +TEST_F(FindAllExprsTest, Test2) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + // Create an ID graph with two cycles + auto tv0 = makeConcreteTensor({10}); + fusion.addInput(tv0); + // One cycle from tv0 to tv1 and then tv2 + auto tv1 = reshape(tv0, {10}, {2, 5}); + auto tv2 = reshape(tv1, {2, 5}, {10}); + // Another cycle from tv0 to tv3 and then tv4 + auto tv3 = reshape(tv0, {10}, {5, 2}); + auto tv4 = reshape(tv3, {5, 2}, {10}); + // Merge tv0, tv2 and tv4 to form cycles + auto tv5 = add(tv0, tv2); + auto tv6 = add(tv5, tv4); + fusion.addOutput(tv6); + 
+ // Effectively, the ID graph looks like: + // + // {tv0, tv2, tv4, tv5, tv6} + // ^ ^ + // | +---------------> {tv1} + // | + // +----------------> {tv3} + + IdModel id_model(&fusion); + const ValGraph& graph = id_model.buildExactGraph(); + + ValGroups tv0_loop_groups = graph.toGroups(tv0->getLoopDomain()); + + // Forward traversal from the tv0 groups. Should visit both tv1 and + // tv3 + { + auto result = getAllExprGroupsBetween( + graph, + tv0_loop_groups, + tv0_loop_groups, + /*require_all_to_visited=*/true, + Direction::Forward) + .first; + + ExprGroupPath reference_path{ + {graph.toGroup(tv3->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv1->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv4->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv2->getLogicalDomain().at(0)->definition()), + Direction::Forward}}; + + VALIDATE_EXPR_PATH(result, reference_path); + } + + // Backward traversal from the tv0 groups. Should visit both tv1 and tv3 + { + auto result = getAllExprGroupsBetween( + graph, + tv0_loop_groups, + tv0_loop_groups, + /*require_all_to_visited=*/true, + Direction::Backward) + .first; + + ExprGroupPath reference_path{ + {graph.toGroup(tv2->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv4->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv1->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv3->getLogicalDomain().at(0)->definition()), + Direction::Backward}}; + + VALIDATE_EXPR_PATH(result, reference_path); + } + + // Forward and backward traversal from the tv0 groups.
Should visit + // both tv1 and tv3 + { + auto result = + getAllExprGroupsBetween(graph, tv0_loop_groups, tv0_loop_groups).first; + + ExprGroupPath reference_path{ + {graph.toGroup(tv3->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv1->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv2->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv4->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv4->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv2->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv1->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv3->getLogicalDomain().at(0)->definition()), + Direction::Backward}}; + + VALIDATE_EXPR_PATH(result, reference_path); + } +} + +// Testing with a graph structure of +// +// tv0 -> tv1,tv5,tv6 -> tv2 -> tv3 -> tv4 +// ^ | +// +-----------------+ +// +// Each edge corresponds to an expr. 
+TEST_F(FindAllExprsTest, Test3) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + auto tv0 = makeConcreteTensor({10}); + fusion.addInput(tv0); + auto tv1 = + slice(tv0, {{fusion.oneVal(), tv0->getLogicalDomain().at(0)->extent()}}); + auto tv2 = + slice(tv1, {{fusion.oneVal(), tv1->getLogicalDomain().at(0)->extent()}}); + auto tv3 = + slice(tv2, {{fusion.oneVal(), tv2->getLogicalDomain().at(0)->extent()}}); + auto tv4 = + slice(tv3, {{fusion.oneVal(), tv3->getLogicalDomain().at(0)->extent()}}); + fusion.addOutput(tv4); + auto tv5 = pad(tv3, {IrBuilder::create(2L), fusion.zeroVal()}); + auto tv6 = add(tv1, tv5); + fusion.addOutput(tv6); + + IdModel id_model(&fusion); + const ValGraph& graph = id_model.buildExactGraph(); + + ValGroups tv0_logical_groups = graph.toGroups(tv0->getLogicalDomain()); + ValGroups tv4_logical_groups = graph.toGroups(tv4->getLogicalDomain()); + + // Forward traversal from tv0. + { + auto result = getAllExprGroupsBetween( + graph, + tv0_logical_groups, + tv4_logical_groups, + /*require_all_to_visited=*/true, + Direction::Forward) + .first; + ExprGroupPath reference_path{ + {graph.toGroup(tv1->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv2->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv3->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv5->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv4->getLogicalDomain().at(0)->definition()), + Direction::Forward}}; + + VALIDATE_EXPR_PATH(result, reference_path); + } + + // Backward traversal from tv4. 
+ { + auto result = getAllExprGroupsBetween( + graph, + tv4_logical_groups, + tv0_logical_groups, + /*require_all_to_visited=*/true, + Direction::Backward) + .first; + ExprGroupPath reference_path{ + {graph.toGroup(tv4->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv3->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv2->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv1->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv5->getLogicalDomain().at(0)->definition()), + Direction::Backward}}; + + VALIDATE_EXPR_PATH(result, reference_path); + } + + // Traversal from tv0 to tv4 with Direction::Undefined. The reference + // path contains both forward and backward exprs. + { + auto result = getAllExprGroupsBetween( + graph, + tv0_logical_groups, + tv4_logical_groups, + /*require_all_to_visited=*/true, + Direction::Undefined) + .first; + ExprGroupPath reference_path{ + {graph.toGroup(tv1->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv2->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv5->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv3->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv4->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv3->getLogicalDomain().at(0)->definition()), + Direction::Backward}, + {graph.toGroup(tv5->getLogicalDomain().at(0)->definition()), + Direction::Forward}, + {graph.toGroup(tv2->getLogicalDomain().at(0)->definition()), + Direction::Backward}}; + + VALIDATE_EXPR_PATH(result, reference_path); + } +} + +// Test with the ROPE rotation pattern +TEST_F(FindAllExprsTest, Rotation) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + std::vector shape({16, 100}); + + EnableOptionsGuard enable_options_guard; + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto tv0 = 
makeConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + + auto tv2 = slice( + tv1, + {{fusion.zeroVal(), tv1->getLogicalDomain().at(0)->extent()}, + {fusion.zeroVal(), IrBuilder::create(shape[1] / 2)}}); + + auto tv3 = sin(tv0); + + auto tv4 = slice( + tv3, + {{fusion.zeroVal(), tv3->getLogicalDomain().at(0)->extent()}, + {IrBuilder::create(shape[1] / 2), + IrBuilder::create(shape[1])}}); + + auto tv5 = cat({tv4, tv2}, 1); + + auto tv6 = add(tv0, tv5); + + fusion.addOutput(tv6); + + IdModel id_model(&fusion); + const ValGraph& graph = id_model.buildExactGraph(); + + // Traversal from tv6 to tv0 should include all exprs + ValGroups tv0_logical_groups = graph.toGroups(tv0->getLogicalDomain()); + ValGroups tv6_logical_groups = graph.toGroups(tv6->getLogicalDomain()); + + { + auto result = getAllExprGroupsBetween( + graph, + tv6_logical_groups, + tv0_logical_groups, + /*require_all_to_visited=*/true, + Direction::Undefined) + .first; + auto tv4_pad = tv5->definition()->input(0)->as(); + auto tv2_pad = tv5->definition()->input(1)->as(); + + ExprGroupPath reference_path{ + {graph.toGroup(tv2->getLogicalDomain().at(1)->definition()), + Direction::Forward}, + {graph.toGroup(tv4->getLogicalDomain().at(1)->definition()), + Direction::Forward}, + {graph.toGroup(tv4_pad->getLogicalDomain().at(1)->definition()), + Direction::Backward}, + {graph.toGroup(tv2_pad->getLogicalDomain().at(1)->definition()), + Direction::Backward}, + {graph.toGroup(tv2_pad->getLogicalDomain().at(1)->definition()), + Direction::Forward}, + {graph.toGroup(tv4_pad->getLogicalDomain().at(1)->definition()), + Direction::Forward}, + {graph.toGroup(tv4->getLogicalDomain().at(1)->definition()), + Direction::Backward}, + {graph.toGroup(tv2->getLogicalDomain().at(1)->definition()), + Direction::Backward}}; + + VALIDATE_EXPR_PATH(result, reference_path); + } +} + } // namespace nvfuser diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index a6692ec21ab..65b76f4a347 100644 
--- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -4756,7 +4756,7 @@ TEST_P(ResizeSchedulerTest, SliceRotateCat) { Fusion& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); - std::vector shape({-1, 128}); + std::vector shape({-1, 100}); EnableOptionsGuard enable_options_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); @@ -4782,7 +4782,7 @@ TEST_P(ResizeSchedulerTest, SliceRotateCat) { auto tv5 = cat({tv4, tv2}, 1); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - auto t0 = at::randn({16, 128}, options); + auto t0 = at::randn({16, 100}, options); fusion.addOutput(tv5); @@ -4888,7 +4888,7 @@ TEST_P(ResizeSchedulerTest, SliceRotateCatResidual) { // slicing paths as well. For now, in order to avoid the error due // to issue #3640, use a size that is divisible by 8. // std::vector shape({16, 100}); - std::vector shape({16, 96}); + std::vector shape({16, 100}); EnableOptionsGuard enable_options_guard; EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); @@ -5903,8 +5903,6 @@ TEST_F(ResizeTest, VectorizeInnermostWithReshapeMerge) { EXPECT_EQ(tv3->getLoopDomain().back()->extent()->evaluate(), 4); } -// don't cache if the input tv is used by slice. 
-// https://github.com/NVIDIA/Fuser/issues/1697 TEST_F(ResizeTest, AvoidCachingSliceInput) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -5984,4 +5982,39 @@ TEST_F(ResizeTest, AvoidCachingSliceInput) { } } +TEST_F(ResizeTest, VectorizeSliceMultiplePaths) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + const int64_t size = 128; + + auto tv0 = makeContigConcreteTensor({size}); + fusion.addInput(tv0); + + auto tv1 = sin(tv0); + auto tv2 = + slice(tv1, {{IrBuilder::create(4L), IrBuilder::create(size)}}); + auto tv3 = slice( + tv1, {{IrBuilder::create(2L), IrBuilder::create(size - 2)}}); + auto tv4 = slice( + tv1, {{IrBuilder::create(0L), IrBuilder::create(size - 4)}}); + auto tv5 = add(tv2, tv3); + auto tv6 = add(tv5, tv4); + fusion.addOutput(tv6); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({size}, options); + + auto outputs = scheduleAndRun(&fusion, SchedulerType::Resize, {t0}); + testValidate(&fusion, outputs.outputs, {t0}, __LINE__, __FILE__); + + // Should be vectorized by a factor of 2 because of the tv3 slice. The + // spanning tree based vectorization analysis may return 4 as only + // one of the paths from tv6 to tv0 is considered. + EXPECT_EQ( + tv6->getLoopDomain().back()->getParallelType(), ParallelType::Vectorize); + EXPECT_EQ(tv6->getLoopDomain().back()->extent()->evaluate(), 2); +} + } // namespace nvfuser