diff --git a/src/relay/analysis/graph_partitioner.cc b/src/relay/analysis/graph_partitioner.cc new file mode 100644 index 000000000000..861fd58d9e5c --- /dev/null +++ b/src/relay/analysis/graph_partitioner.cc @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./graph_partitioner.h" + +#include + +namespace tvm { +namespace relay { + +DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { + DominatorTree tree; + tree.nodes.resize(graph.post_dfs_order.size(), nullptr); + // reverse topo order + for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { + size_t index = i - 1; + tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); + } + return tree; +} + +DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs, + OpPatternKind* edge_pattern) { + while (lhs != rhs) { + if (lhs == nullptr) return nullptr; + if (rhs == nullptr) return nullptr; + if (lhs->depth < rhs->depth) { + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); + rhs = rhs->parent; + } else if (rhs->depth < lhs->depth) { + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); + lhs = lhs->parent; + } else { + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); + lhs = lhs->parent; + rhs = rhs->parent; + } + } + return lhs; +} + +DominatorTree::Node* DominatorTree::LeastCommonAncestor( + const LinkedList& input_nodes, OpPatternKind* edge_pattern) { + auto link = input_nodes.head; + if (link == nullptr) { + return nullptr; + } + auto get_node = [&](const IndexedForwardGraph::Edge& edge) { + size_t oindex = edge.node->index; + ICHECK_LT(oindex, nodes.size()); + Node* onode = nodes[oindex]; + ICHECK(onode != nullptr); + return onode; + }; + Node* parent = get_node(link->value); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + link = link->next; + for (; link != nullptr; link = link->next) { + parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + } + return parent; +} + +DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena, + IndexedForwardGraph::Node* gnode) { + Node* tnode = arena->make(); + tnode->gnode = gnode; + if (gnode->extern_ref) { + tnode->depth = 1; + tnode->parent = nullptr; + tnode->pattern = kOpaque; + } else { + // find the LCAs of all outputs. + OpPatternKind pattern = kElemWise; + Node* parent = LeastCommonAncestor(gnode->outputs, &pattern); + tnode->depth = parent ? parent->depth + 1 : 1; + tnode->parent = parent; + tnode->pattern = pattern; + } + return tnode; +} + +std::vector GraphPartitioner::Partition( + const IndexedForwardGraph& graph) { + this->InitGroups(graph); + if (opt_level_ == 0) return std::move(groups_); + // get post dominator tree + auto post_dom_tree = DominatorTree::PostDom(arena_, graph); + // run fusion algorithm. + for (int phase = 0; phase < 3; ++phase) { + this->RunFuse(graph, post_dom_tree, phase); + } + return std::move(groups_); +} + +GraphPartitioner::Group* GraphPartitioner::Group::FindRoot() { + // fast path + if (this->parent == nullptr) return this; + // slow path with path compression. + Group* root = this; + while (root->parent != nullptr) { + root = root->parent; + } + for (Group* p = this; p != root;) { + Group* parent = p->parent; + p->parent = root; + p = parent; + } + return root; +} + +template +bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + F fcond) { + if (visited_.count(src)) return true; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + gnode = gnode->FindRoot(); + if (!fcond(gnode->pattern, src == sink)) return false; + if (src == sink) return true; + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + if (!CheckPath_(link->value.node, sink, fcond)) return false; + } + return true; +} + +template +bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + F fcond) { + ICHECK(!src->extern_ref); + visited_.clear(); + ICHECK(src != sink); + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + if (!CheckPath_(link->value.node, sink, fcond)) return false; + } + return true; +} + +OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { + if (lhs > relay::kBroadcast && rhs > relay::kBroadcast) { + LOG(FATAL) << "Cannot merge two complex group together"; + } + if (lhs > rhs) return lhs; + return rhs; +} + +void GraphPartitioner::MergeFromTo(Group* child, Group* parent) { + child = child->FindRoot(); + parent = parent->FindRoot(); + if (child == parent) return; + // update the number of nodes of the parent group + parent->num_nodes += child->num_nodes; + child->parent = parent; + // update anchor ref and pattern + if (child->anchor_ref != nullptr) { + ICHECK(parent->anchor_ref == nullptr); + parent->anchor_ref = child->anchor_ref; + parent->pattern = CombinePattern(child->pattern, parent->pattern); + } +} + +void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + Group* target) { + if (src == sink) return; + if (visited_.count(src)) return; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + // merge the current group to the parent if possible. + MergeFromTo(gnode, target); + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + CommitFuse_(link->value.node, sink, target); + } +} + +void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { + Group* target = groups_[sink->index]; + visited_.clear(); + ICHECK(src != sink); + CommitFuse_(src, sink, target); +} + +size_t GraphPartitioner::CountNodesUptoSink_(IndexedForwardGraph::Node* src, + IndexedForwardGraph::Node* sink) { + if (src == sink || visited_.count(src)) return 0; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + auto sum = gnode->num_nodes; + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + sum += CountNodesUptoSink_(link->value.node, sink); + } + return sum; +} + +size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, + IndexedForwardGraph::Node* dom_parent) { + Group* target = groups_[dom_parent->index]; + visited_.clear(); + ICHECK(child != dom_parent); + return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); +} + +void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) { + groups_.resize(graph.post_dfs_order.size()); + for (size_t nid = 0; nid < groups_.size(); ++nid) { + const auto* graph_node = graph.post_dfs_order[nid]; + auto* group_node = arena_->make(); + group_node->pattern = graph_node->pattern; + group_node->root_ref = graph_node->ref; + // set anchor ref if necessary. + if (group_node->pattern == relay::kOutEWiseFusable) { + group_node->anchor_ref = graph_node->ref; + } + groups_[nid] = group_node; + } +} + +void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, // + const DominatorTree& post_dom_tree, // + int phase) { + for (size_t nid = 0; nid < groups_.size(); ++nid) { + // the group of current node has been specified already. + auto* graph_node = graph.post_dfs_order[nid]; + auto* dom_node = post_dom_tree.nodes[nid]; + Group* group_node = groups_[nid]; + ICHECK(group_node != nullptr); + // no actions for opaque nodes + if (group_node->pattern == kOpaque) continue; + // no actions needed if the current node have no dominator + if (dom_node->parent == nullptr) continue; + ICHECK(!graph_node->extern_ref); + size_t dom_parent_gindex = dom_node->parent->gnode->index; + + // refuse the fusion if too many ops are going to be fused together + if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) + continue; + + if (phase == 2) { + // Fuse injective ops into intermediate tuples, if any + if (group_node->pattern > relay::kInjective) continue; + Group* dom_parent_group = groups_[dom_parent_gindex]; + Group* dom_root_group = dom_parent_group->FindRoot(); + // If dom node group has a tuple as its root, we do not fuse tuple fields into it + if (dom_root_group->pattern == relay::kTuple) continue; + if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) { + // Now we know the tuple has been fused into subsequent injective ops + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; + // dom_root_group can also be tuple, as in inception layers + // CheckPath is needed to avoid fusing two intermediate tuples + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + continue; + } + + // Skip if current node is already fused to the parent. + if (groups_[dom_parent_gindex] != nullptr && + group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { + continue; + } + // Do not fuse into tuple for now + if (groups_[dom_parent_gindex]->pattern == kTuple) continue; + // Try to fuse current node to its post-dominator. + if (group_node->pattern == kOutEWiseFusable) { + if (phase != 0) continue; + // Path for OutEWiseFusable: conv2d + // Check if the dominator relation is elemwise. + if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { + ICHECK(dom_node->parent->gnode != nullptr); + // The fuse can be executed if all the intermediate ops are still broadcast. + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + } else if (group_node->pattern <= kBroadcast) { + // Pre-condition: can only be fused to parent which is injective or reduction. + if (dom_node->parent != nullptr && + (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) { + // Check if all the intermediate ops are still broadcast. + // The final terminal node can already be fused to a OutEWiseFusable group. + auto fcond = [](OpPatternKind kind, bool is_sink) { + if (!is_sink) { + // Elemwise, broadcast, and injective ops on the parallel branches + // are allowed be fused to the elemwise/broadcast anchor. + return kind <= kInjective; + } else { + return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective || + kind == kOutEWiseFusable); + } + }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) { + // defer injective fusion to second phase. + // so conv2d always finishes fusing. + if (phase != 1) continue; + // Check if all path are injective. + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } else { + // do nothing. + ICHECK(group_node->pattern == kCommReduce); + } + } +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/analysis/graph_partitioner.h b/src/relay/analysis/graph_partitioner.h new file mode 100644 index 000000000000..9433aafa119d --- /dev/null +++ b/src/relay/analysis/graph_partitioner.h @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/graph_partitioner.h + * \brief The helper function for op fusion. + */ + +#ifndef TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ +#define TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ + +#include + +#include +#include +#include + +#include "../../support/arena.h" + +namespace tvm { +namespace relay { + +using support::LinkedList; +using support::LinkNode; + +/*! + * \brief Indexed data flow graph in forward direction. + * This is a temporary data structure used for operator fusion analysis. + * + * This data structure only captures the dataflow fragment and + * could ignore blocks like let by simply ordering each dataflow block + * and mark the output node as extern_ref; + */ +class IndexedForwardGraph { + public: + struct Node; + /*! + * The forward edge in the dataflow graph. + */ + struct Edge { + /*! \brief The corresponding node */ + Node* node{nullptr}; + /*! \brief The respective pattern of this op */ + OpPatternKind pattern{kOpaque}; + }; + /*! \brief A node in the graph. */ + struct Node { + /*! \brief weak reference to the corresponding edge. */ + const tvm::Object* ref{nullptr}; + /*! \brief The index of the node in topological order. */ + size_t index{0}; + /*! \brief Whether this node is referenced by external source */ + bool extern_ref{false}; + /*! \brief The general pattern in the node */ + OpPatternKind pattern{kOpaque}; + /*! \brief The outputs of the node. */ + LinkedList outputs; + }; + /*! \brief The node map that maps node to graph */ + std::unordered_map node_map; + /*! \brief All the nodes in post DFS order */ + std::vector post_dfs_order; + + /*! \brief Dump the graph into string. */ + void DebugDump() { + std::ostringstream os; + for (size_t i = 0; i < post_dfs_order.size(); ++i) { + Node* node = post_dfs_order[i]; + os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; + for (auto* link = node->outputs.head; link != nullptr; link = link->next) { + os << link->value.node->index << ", "; + } + os << "]\n"; + } + LOG(INFO) << os.str(); + } +}; + +/*! + * \brief Dominator tree that represent domination or + * post domination relation of the node. + */ +class DominatorTree { + public: + /*! + * \brief A node in the dominator tree. + */ + struct Node { + /*! \brief The node in the tree */ + IndexedForwardGraph::Node* gnode{nullptr}; + /*! \brief parent of the tree */ + Node* parent{nullptr}; + /*! \brief current depth*/ + int depth{0}; + /*! \brief aggregated pattern to parent */ + OpPatternKind pattern{kOpaque}; + }; + // index -> node. + std::vector nodes; + /*! + * \brief compute a post dominator relation for a given dataflow graph. + * \param arena The arena used for node allocation. + * \param graph The graph to be analyzed. + * \return The dominator tree of the graph. + * \note This algorithm makes use of the fact that graph is DAG, + * and runs a single pass algorithm via LCA (Least Common Ancestor) + */ + static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph); + + private: + // Combine pattern together. + inline static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { + if (lhs > rhs) return lhs; + return rhs; + } + /*! + * \brief Find the least common ancestor of the two nodes. + * \param lhs The left node. + * \param rhs The right node. + * \param edge_pattern + * The combined edge pattern across all the parents. + * \return The least common ancestor of the two. + */ + static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern); + /*! + * \brief Find the least common ancestor of a list of nodes. + * \param nodes the nodes. + * \param edge_pattern + * The combined edge pattern across all the parents. + * \return The least common ancestor of all nodes. + */ + Node* LeastCommonAncestor(const LinkedList& input_nodes, + OpPatternKind* edge_pattern); + + /*! + * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node. + * \param arena The Arena. + * \param gnode An IndexedForwardGraph Node. + * \return The DominatorTree Node. + */ + Node* GetNode(support::Arena* arena, IndexedForwardGraph::Node* gnode); +}; + +/*! + * \brief A partition of the graph marked by union find data structure. + */ +class GraphPartitioner { + public: + explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth) + : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {} + /*! + * \brief Group as a union find data structure. + */ + struct Group { + /*! \brief The parent in the union find data structure. */ + Group* parent{nullptr}; + /*! \brief The pattern of the group */ + OpPatternKind pattern; + /*! \brief reference to the root node. */ + const tvm::Object* root_ref{nullptr}; + /*! + * \brief Reference to the anchor node, + * this field is not nullptr only if pattern is kOutEWiseFusable. + */ + const tvm::Object* anchor_ref{nullptr}; + /*! + * \brief The number of nodes belonging to this group + */ + uint32_t num_nodes{1}; + + /*! \brief Optional attributes to annotate the grouped function. */ + runtime::Map attrs; + /*! + * \brief Find the group root, perform path compression + * \return The root type node. + */ + Group* FindRoot(); + }; + /*! + * \brief Partition a graph. + * \return group assignments of each node. + */ + std::vector Partition(const IndexedForwardGraph& graph); + + private: + /*! \brief The internal arena for temporary space. */ + support::Arena* arena_; + /*! \brief optimization level for fuse operation. */ + int opt_level_; + /*! \brief The maximum number of operations in one fused function */ + size_t max_fuse_depth_; + /*! \brief The internal groups. */ + std::vector groups_; + /*! \brief internal field used for deduplication */ + std::unordered_set visited_; + // Internal implementation of CheckPath + template + bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond); + + /*! + * \brief Check all the node and edge pattern + * between src and sink satisfies fcond. + * + * src is not checked. + * + * \param src The source node. + * \param sink The termination node. + * \param fcond The condition to be checked. + * \tparam F the condition function, with signature + * \note sink must be a post-dominator of src. + */ + template + bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond); + + /*! + * \brief Merge the child group to the parent. + * \param child The child group. + * \param parent The parent group. + */ + void MergeFromTo(Group* child, Group* parent); + + // Internal implementation of CommitFuse + void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target); + + /*! + * \brief Commit fusion operation. + * \param src The source node. + * \param sink The termination node. + * \note sink must be a post-dominator of src. + */ + void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); + + size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); + + // Count the number of nodes in a fused subgraph if child is additionally fused. + // dom_parent is already known to be a part of the subgraph. + // For a diamond structure, there can be multiple paths connecting child and dom_parent. + // All intermediate nodes between child and dom_parent are taken into account. + // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot() + // is important for correct calculation. + size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, + IndexedForwardGraph::Node* dom_parent); + + // Initialize the groups. + void InitGroups(const IndexedForwardGraph& graph); + + // execute the fusion algorithm. + void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase); +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index afa60f1bb4e5..1fb857cb1cb3 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -32,6 +32,7 @@ #include #include "../../support/arena.h" +#include "../analysis/graph_partitioner.h" #include "../op/annotation/annotation.h" #include "./pass_utils.h" #include "./pattern_utils.h" @@ -88,72 +89,16 @@ static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion"); TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.link_params", Bool); -/*! - * \brief Indexed data flow graph in forward direction. - * This is a temporary data structure used for operator fusion analysis. - * - * This data structure only captures the dataflow fragment and - * could ignore blocks like let by simply ordering each dataflow block - * and mark the output node as extern_ref; - */ -class IndexedForwardGraph { +// Creator of post dominator tree of the dataflow +class IndexedForwardGraphCreator : private ExprVisitor { public: - struct Node; - /*! - * The forward edge in the dataflow graph. - */ - struct Edge { - /*! \brief The corresponding node */ - Node* node{nullptr}; - /*! \brief The respective pattern of this op */ - OpPatternKind pattern{kOpaque}; - }; - /*! \brief A node in the graph. */ - struct Node { - /*! \brief weak reference to the corresponding edge. */ - const tvm::Object* ref{nullptr}; - /*! \brief The index of the node in topological order. */ - size_t index{0}; - /*! \brief Whether this node is referenced by external source */ - bool extern_ref{false}; - /*! \brief The general pattern in the node */ - OpPatternKind pattern{kOpaque}; - /*! \brief The outputs of the node. */ - LinkedList outputs; - }; - /*! \brief The node map that maps node to graph */ - std::unordered_map node_map; - /*! \brief All the nodes in post DFS order */ - std::vector post_dfs_order; - - /*! \brief Dump the graph into string. */ - void DebugDump() { - std::ostringstream os; - for (size_t i = 0; i < post_dfs_order.size(); ++i) { - Node* node = post_dfs_order[i]; - os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; - for (auto* link = node->outputs.head; link != nullptr; link = link->next) { - os << link->value.node->index << ", "; - } - os << "]\n"; - } - LOG(INFO) << os.str(); + static IndexedForwardGraph Create(support::Arena* arena, const Expr& body) { + IndexedForwardGraphCreator creator(arena); + return creator.Prepare(body); } - /*! - * \brief create a indexed forward graph. - * \param arena The arena used for data allocation. - * \param body The body of the expression to create a graph. - */ - static IndexedForwardGraph Create(support::Arena* arena, const Expr& body); private: - class Creator; -}; - -// Creator of post dominator tree of the dataflow -class IndexedForwardGraph::Creator : private ExprVisitor { - public: - explicit Creator(support::Arena* arena) : arena_(arena) {} + explicit IndexedForwardGraphCreator(support::Arena* arena) : arena_(arena) {} IndexedForwardGraph Prepare(const Expr& body) { this->Update(body, nullptr, kOpaque); @@ -213,7 +158,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const ConstantNode* op) final { this->AddNode(op); - Node* node = graph_.node_map.at(op); + IndexedForwardGraph::Node* node = graph_.node_map.at(op); DataType dtype = DataType(op->data->dtype); // This rule must be consistent with code generator. bool is_simple_const = @@ -230,7 +175,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const CallNode* call) final { ICHECK(graph_.node_map.count(call)); - Node* node = graph_.node_map.at(call); + IndexedForwardGraph::Node* node = graph_.node_map.at(call); static auto fpattern = Op::GetAttrMap("TOpPattern"); // Now we set the pattern of this call. // @@ -274,7 +219,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const TupleNode* op) final { ICHECK(graph_.node_map.count(op)); - Node* tuple_node = graph_.node_map.at(op); + IndexedForwardGraph::Node* tuple_node = graph_.node_map.at(op); tuple_node->pattern = kTuple; for (const Expr& field : op->fields) { if (field->checked_type().as()) { @@ -306,7 +251,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->Update(op->tuple, nullptr, kOpaque); } else { ICHECK(graph_.node_map.count(op)); - Node* node = graph_.node_map.at(op); + IndexedForwardGraph::Node* node = graph_.node_map.at(op); node->pattern = kInjective; this->Update(op->tuple, node, kInjective); } @@ -372,443 +317,6 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } }; -IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) { - return Creator(arena).Prepare(body); -} - -/*! - * \brief Dominator tree that represent domination or - * post domination relation of the node. - */ -class DominatorTree { - public: - /*! - * \brief A node in the dominator tree. - */ - struct Node { - /*! \brief The node in the tree */ - IndexedForwardGraph::Node* gnode{nullptr}; - /*! \brief parent of the tree */ - Node* parent{nullptr}; - /*! \brief current depth*/ - int depth{0}; - /*! \brief aggregated pattern to parent */ - OpPatternKind pattern{kOpaque}; - }; - // index -> node. - std::vector nodes; - /*! - * \brief compute a post dominator relation for a given dataflow graph. - * \param arena The arena used for node allocation. - * \param graph The graph to be analyzed. - * \return The dominator tree of the graph. - * \note This algorithm makes use of the fact that graph is DAG, - * and runs a single pass algorithm via LCA (Least Common Ancestor) - */ - static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph); - - private: - // Combine pattern together. - static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { - if (lhs > rhs) return lhs; - return rhs; - } - /*! - * \brief Find the least common ancestor of the two nodes. - * \param lhs The left node. - * \param rhs The right node. - * \param edge_pattern - * The combined edge pattern across all the parents. - * \return The least common ancestor of the two. - */ - static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) { - while (lhs != rhs) { - if (lhs == nullptr) return nullptr; - if (rhs == nullptr) return nullptr; - if (lhs->depth < rhs->depth) { - edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); - rhs = rhs->parent; - } else if (rhs->depth < lhs->depth) { - edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); - lhs = lhs->parent; - } else { - edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); - edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); - lhs = lhs->parent; - rhs = rhs->parent; - } - } - return lhs; - } - /*! - * \brief Find the least common ancestor of a list of nodes. - * \param nodes the nodes. - * \param edge_pattern - * The combined edge pattern across all the parents. - * \return The least common ancestor of all nodes. - */ - Node* LeastCommonAncestor(const LinkedList& input_nodes, - OpPatternKind* edge_pattern) { - auto link = input_nodes.head; - if (link == nullptr) { - return nullptr; - } - auto get_node = [&](const IndexedForwardGraph::Edge& edge) { - size_t oindex = edge.node->index; - ICHECK_LT(oindex, nodes.size()); - Node* onode = nodes[oindex]; - ICHECK(onode != nullptr); - return onode; - }; - Node* parent = get_node(link->value); - *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); - link = link->next; - for (; link != nullptr; link = link->next) { - parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern); - *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); - } - return parent; - } - /*! - * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node. - * \param arena The Arena. - * \param gnode An IndexedForwardGraph Node. - * \return The DominatorTree Node. - */ - Node* GetNode(support::Arena* arena, IndexedForwardGraph::Node* gnode) { - Node* tnode = arena->make(); - tnode->gnode = gnode; - if (gnode->extern_ref) { - tnode->depth = 1; - tnode->parent = nullptr; - tnode->pattern = kOpaque; - } else { - // find the LCAs of all outputs. - OpPatternKind pattern = kElemWise; - Node* parent = LeastCommonAncestor(gnode->outputs, &pattern); - tnode->depth = parent ? parent->depth + 1 : 1; - tnode->parent = parent; - tnode->pattern = pattern; - } - return tnode; - } -}; - -DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { - DominatorTree tree; - tree.nodes.resize(graph.post_dfs_order.size(), nullptr); - // reverse topo order - for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { - size_t index = i - 1; - tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); - } - return tree; -} - -/*! - * \brief A partition of the graph marked by union find data structure. - */ -class GraphPartitioner { - public: - explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth) - : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {} - /*! - * \brief Group as a union find data structure. - */ - struct Group { - /*! \brief The parent in the union find data structure. */ - Group* parent{nullptr}; - /*! \brief The pattern of the group */ - OpPatternKind pattern; - /*! \brief reference to the root node. */ - const tvm::Object* root_ref{nullptr}; - /*! - * \brief Reference to the anchor node, - * this field is not nullptr only if pattern is kOutEWiseFusable. - */ - const tvm::Object* anchor_ref{nullptr}; - /*! - * \brief Find the group root, perform path compression - * \return The root type node. - */ - Group* FindRoot() { - // fast path - if (this->parent == nullptr) return this; - // slow path with path compression. - Group* root = this; - while (root->parent != nullptr) { - root = root->parent; - } - for (Group* p = this; p != root;) { - Group* parent = p->parent; - p->parent = root; - p = parent; - } - return root; - } - - /*! - * \brief The number of nodes belonging to this group - */ - uint32_t num_nodes{1}; - }; - /*! - * \brief Partition a graph. - * \return group assignments of each node. - */ - std::vector Partition(const IndexedForwardGraph& graph); - - private: - /*! \brief The internal arena for temporary space. */ - support::Arena* arena_; - /*! \brief optimization level for fuse operation. */ - int opt_level_; - /*! \brief The maximum number of operations in one fused function */ - size_t max_fuse_depth_; - /*! \brief The internal groups. */ - std::vector groups_; - /*! \brief internal field used for deduplication */ - std::unordered_set visited_; - // Internal implelementation of CheckPath - template - bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { - if (visited_.count(src)) return true; - visited_.insert(src); - Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); - gnode = gnode->FindRoot(); - if (!fcond(gnode->pattern, src == sink)) return false; - if (src == sink) return true; - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - if (!CheckPath_(link->value.node, sink, fcond)) return false; - } - return true; - } - /*! - * \brief Check all the node and edge pattern - * between src and sink satisfies fcond. - * - * src is not checked. - * - * \param src The source node. - * \param sink The termination node. - * \param fcond The condition to be checked. - * \tparam F the condition function, with signature - * \note sink must be a post-dominator of src. - */ - template - bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { - ICHECK(!src->extern_ref); - visited_.clear(); - ICHECK(src != sink); - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - if (!CheckPath_(link->value.node, sink, fcond)) return false; - } - return true; - } - // Combine two patterns together. - static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { - if (lhs > kBroadcast && rhs > kBroadcast) { - LOG(FATAL) << "Cannot merge two complex group together"; - } - if (lhs > rhs) return lhs; - return rhs; - } - /*! - * \brief Merge the child group to the parent. - * \param child The child group. - * \param parent The parent group. - */ - void MergeFromTo(Group* child, Group* parent) { - child = child->FindRoot(); - parent = parent->FindRoot(); - if (child == parent) return; - // update the number of nodes of the parent group - parent->num_nodes += child->num_nodes; - child->parent = parent; - // update anchor ref and pattern - if (child->anchor_ref != nullptr) { - ICHECK(parent->anchor_ref == nullptr); - parent->anchor_ref = child->anchor_ref; - parent->pattern = CombinePattern(child->pattern, parent->pattern); - } - } - // Internal implelementation of CommitFuse - void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target) { - if (src == sink) return; - if (visited_.count(src)) return; - visited_.insert(src); - Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); - // merge the current group to the parent if possible. - MergeFromTo(gnode, target); - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - CommitFuse_(link->value.node, sink, target); - } - } - /*! - * \brief Commit fusion operation. - * \param src The source node. - * \param sink The termination node. - * \note sink must be a post-dominator of src. - */ - void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { - Group* target = groups_[sink->index]; - visited_.clear(); - ICHECK(src != sink); - CommitFuse_(src, sink, target); - } - - size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { - if (src == sink || visited_.count(src)) return 0; - visited_.insert(src); - Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); - auto sum = gnode->num_nodes; - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - sum += CountNodesUptoSink_(link->value.node, sink); - } - return sum; - } - - // Count the number of nodes in a fused subgraph if child is additionaly fused. - // dom_parent is already known to be a part of the subgraph. - // For a diamond structure, there can be multiple paths connecting child and dom_parent. - // All intermediate nodes between child and dom_parent are taken into account. - // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot() - // is important for correct calculation. - size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, - IndexedForwardGraph::Node* dom_parent) { - Group* target = groups_[dom_parent->index]; - visited_.clear(); - ICHECK(child != dom_parent); - return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); - } - - // Initialize the groups. - void InitGroups(const IndexedForwardGraph& graph) { - groups_.resize(graph.post_dfs_order.size()); - for (size_t nid = 0; nid < groups_.size(); ++nid) { - const auto* graph_node = graph.post_dfs_order[nid]; - auto* group_node = arena_->make(); - group_node->pattern = graph_node->pattern; - group_node->root_ref = graph_node->ref; - // set anchor ref if necessary. - if (group_node->pattern == kOutEWiseFusable) { - group_node->anchor_ref = graph_node->ref; - } - groups_[nid] = group_node; - } - } - - // execute the fusion algorithm. - void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) { - for (size_t nid = 0; nid < groups_.size(); ++nid) { - // the group of current node has been specified already. - auto* graph_node = graph.post_dfs_order[nid]; - auto* dom_node = post_dom_tree.nodes[nid]; - Group* group_node = groups_[nid]; - ICHECK(group_node != nullptr); - // no actions for opaque nodes - if (group_node->pattern == kOpaque) continue; - // no actions needed if the current node have no dominator - if (dom_node->parent == nullptr) continue; - ICHECK(!graph_node->extern_ref); - size_t dom_parent_gindex = dom_node->parent->gnode->index; - - // refuse the fusion if too many ops are going to be fused together - if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) - continue; - - if (phase == 2) { - // Fuse injective ops into intermediate tuples, if any - if (group_node->pattern > kInjective) continue; - Group* dom_parent_group = groups_[dom_parent_gindex]; - Group* dom_root_group = dom_parent_group->FindRoot(); - // If dom node group has a tuple as its root, we do not fuse tuple fields into it - if (dom_root_group->pattern == kTuple) continue; - if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) { - // Now we know the tuple has been fused into subsequent injective ops - auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; - // dom_root_group can also be tuple, as in inception layers - // CheckPath is needed to avoid fusing two intermediate tuples - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } - continue; - } - - // Skip if current node is already fused to the parent. - if (groups_[dom_parent_gindex] != nullptr && - group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { - continue; - } - // Do not fuse into tuple for now - if (groups_[dom_parent_gindex]->pattern == kTuple) continue; - // Try to fuse current node to its post-dominator. - if (group_node->pattern == kOutEWiseFusable) { - if (phase != 0) continue; - // Path for OutEWiseFusable: conv2d - // Check if the dominator relation is elemwise. - if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { - ICHECK(dom_node->parent->gnode != nullptr); - // The fuse can be executed if all the intermediate ops are still broadcast. - auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } - } else if (group_node->pattern <= kBroadcast) { - // Pre-condition: can only be fused to parent which is injective or reduction. - if (dom_node->parent != nullptr && - (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) { - // Check if all the intermediate ops are still broadcast. - // The final terminal node can already be fused to a OutEWiseFusable group. - auto fcond = [](OpPatternKind kind, bool is_sink) { - if (!is_sink) { - // Elemwise, broadcast, and injective ops on the parallel branches - // are allowed be fused to the elemwise/broadcast anchor. - return kind <= kInjective; - } else { - return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective || - kind == kOutEWiseFusable); - } - }; - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } - } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) { - // defer injective fusion to second phase. - // so conv2d always finishes fusing. - if (phase != 1) continue; - // Check if all path are injective. - auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } else { - // do nothing. - ICHECK(group_node->pattern == kCommReduce); - } - } - } -}; - -std::vector GraphPartitioner::Partition( - const IndexedForwardGraph& graph) { - this->InitGroups(graph); - if (opt_level_ == 0) return std::move(groups_); - // get post dominator tree - auto post_dom_tree = DominatorTree::PostDom(arena_, graph); - // run fusion algorithm. - for (int phase = 0; phase < 3; ++phase) { - this->RunFuse(graph, post_dom_tree, phase); - } - return std::move(groups_); -} - class FuseMutator : private MixedModeMutator { public: FuseMutator(int fuse_opt_level, size_t max_fuse_depth, bool link_params) @@ -825,7 +333,7 @@ class FuseMutator : private MixedModeMutator { // Run the transform Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) { // setup the group map. - auto graph = IndexedForwardGraph::Create(&arena_, body); + auto graph = IndexedForwardGraphCreator::Create(&arena_, body); auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph); for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { ICHECK(graph.post_dfs_order[nid]->ref != nullptr);