From 6a7e7f3e8315dff9170e0a689d97c09912167cd3 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 21 Feb 2020 19:14:45 +0000 Subject: [PATCH 1/3] call graph for relay --- python/tvm/relay/__init__.py | 4 + python/tvm/relay/call_graph.py | 143 ++++++++ src/relay/pass/call_graph.cc | 339 +++++++++++++++++ src/relay/pass/call_graph.h | 509 ++++++++++++++++++++++++++ tests/python/relay/test_call_graph.py | 150 ++++++++ 5 files changed, 1145 insertions(+) create mode 100644 python/tvm/relay/call_graph.py create mode 100644 src/relay/pass/call_graph.cc create mode 100644 src/relay/pass/call_graph.h create mode 100644 tests/python/relay/test_call_graph.py diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 0df3747a93b1..2ad210e7d109 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -19,6 +19,7 @@ import os from sys import setrecursionlimit from ..api import register_func +from . import call_graph from . import base from . import ty from . import expr @@ -141,3 +142,6 @@ # Feature Feature = feature.Feature + +# CallGraph +CallGraph = call_graph.CallGraph diff --git a/python/tvm/relay/call_graph.py b/python/tvm/relay/call_graph.py new file mode 100644 index 000000000000..104ccda8d585 --- /dev/null +++ b/python/tvm/relay/call_graph.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import +"""Call graph used in Relay.""" + +from tvm.ir import IRModule +from .base import Object +from .expr import GlobalVar +from . import _analysis + + +class CallGraph(Object): + """Class to represent a call graph.""" + + def __init__(self, module): + """Construct a call graph. + + Parameters + ---------- + module : tvm.ir.IRModule + The IR module used to create a call graph + + Returns + ------- + call_graph: CallGraph + A constructed call graph. + """ + self.__init_handle_by_constructor__(_analysis.CallGraph, module) + + @property + def module(self): + """Return the contained Relay IR module. + + Parameters + ---------- + None + + Returns + ------- + ret : tvm.ir.IRModule + The contained IRModule + """ + return _analysis.GetModule(self) + + def ref_count(self, var): + """Return the number of references to the global var + + Parameters + ---------- + var : Union[String, tvm.relay.GlobalVar] + + Returns + ------- + ret : int + The number reference to the global var + """ + var = self._get_global_var(var) + return _analysis.GetRefCountGlobalVar(self, var) + + def global_call_count(self, var): + """Return the number of global function calls from a given global var. + + Parameters + ---------- + var : Union[String, tvm.relay.GlobalVar] + + Returns + ------- + ret : int + The number of global function calls from the given var. + """ + var = self._get_global_var(var) + return _analysis.GetGlobalVarCallCount(self, var) + + def is_recursive(self, var): + """Return the number of global function calls from a given global var. + + Parameters + ---------- + var : Union[String, tvm.relay.GlobalVar] + + Returns + ------- + ret : Boolean + If the function corresponding to var is recurisve. + """ + var = self._get_global_var(var) + return _analysis.IsRecursive(self, var) + + def _get_global_var(self, var): + """Return the global var using a given name or GlobalVar. + + Parameters + ---------- + var : Union[String, tvm.relay.GlobalVar] + + Returns + ------- + ret : tvm.relay.GlobalVar + The global var. + """ + if isinstance(var, str): + mod = self.module + var = mod.get_global_var(var) + + if isinstance(var, GlobalVar): + return var + else: + raise TypeError("var should be either a string or GlobalVar") + + def __str__(self): + """Print the call graph in the topological order.""" + return _analysis.PrintCallGraph(self) + + def __getitem__(self, var): + """Lookup a call graph of a global function by name or by variable. + + Parameters + ---------- + var: Union[String, tvm.relay.GlobalVar] + The name or global variable. + + Returns + ------- + ret : String + The call graph represented in string. + """ + var = self._get_global_var(var) + return _analysis.GetCallGraphGlobalVar(self, var) diff --git a/src/relay/pass/call_graph.cc b/src/relay/pass/call_graph.cc new file mode 100644 index 000000000000..5a4b6a91c04a --- /dev/null +++ b/src/relay/pass/call_graph.cc @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/pass/call_graph.cc + * \brief Implementation of APIs to handle the call graph of a Relay module. + */ + +#include "call_graph.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +CallGraph::CallGraph(IRModule module) { + auto n = make_object(); + n->module = std::move(module); + auto gvar_funcs = n->module->functions; + for (const auto& it : gvar_funcs) { + if (const auto* fn = it.second.as()) { + auto func = GetRef(fn); + // Add the global function to gradually build up the call graph. + n->AddToCallGraph(it.first, func); + } + } + data_ = std::move(n); +} + +void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) { + CHECK(func.defined() && gv.defined()); + // Add the current global function as an entry to the call grpah. + CallGraphEntryNode* cg_node = LookupGlobalVar(gv); + + // Only GlobalVar nodes need to be handled in a function. It indicates that + // the global function of a callee is called by the function that is being + // processed. An edge will be added from the current global function, cg_node, + // to the node that contains the found callee GlobalVarNode. + // + // This is the major overhead for constructing a call graph because the + // post-order visitor will visit each AST node of the current function to + // figure out the dependencies between functions. + PostOrderVisit(func, [&](const Expr& expr) { + if (const GlobalVarNode* gvn = expr.as()) { + auto callee = GetRef(gvn); + cg_node->AddCalledGlobal(LookupGlobalVar(callee)); + } + }); +} + +const CallGraphEntryNode* CallGraphNode::operator[](const GlobalVar& gv) const { + const_iterator cit = call_graph_.find(gv); + CHECK(cit != call_graph_.end()) + << "GlobalVar " << gv->name_hint << " not found in the call graph!"; + return cit->second.get(); +} + +CallGraphEntryNode* CallGraphNode::operator[](const GlobalVar& gv) { + const_iterator cit = call_graph_.find(gv); + CHECK(cit != call_graph_.end()) + << "GlobalVar " << gv->name_hint << " not found in the call graph!"; + return cit->second.get(); +} + +// Query the existence of a GlobalVar in the call graph. It creates an entry if +// there is no such a node available. +CallGraphEntryNode* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) { + CHECK(gv.defined()); + + // This inserts an element to the call graph if it is not there yet. + auto& call_graph_node = call_graph_[gv]; + if (call_graph_node) return call_graph_node.get(); + + CHECK(module->ContainGlobalVar(gv->name_hint)) + << "GlobalVar " << gv->name_hint << " not found in the current ir module"; + + // Create the node for the inserted entry. + call_graph_node = std::unique_ptr(new CallGraphEntryNode(gv)); + return call_graph_node.get(); +} + +void CallGraphNode::Print(std::ostream& os) const { + // Print the call graph in the topological order. + std::vector nodes = TopologicalOrder(); + for (const auto* cgn : nodes) { + cgn->Print(os); + } +} + +GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntryNode* cg_node, + bool update_call_graph) { + CHECK(cg_node->empty() || (cg_node->IsRecursive() && cg_node->size() == 1)) + << "Cannot remove global var " << cg_node->GetNameHint() + << " from call graph, because it still calls " + << cg_node->size() << " other global functions"; + + if (update_call_graph) { + // Update the call graph by removing all edges that point to the node + // `cg_node`. + for (auto& it : *this) { + it.second->RemoveAllCallTo(cg_node); + } + } + GlobalVar gv = cg_node->GetGlobalVar(); + call_graph_.erase(gv); + // Update the IR module. + module->Remove(gv); + return gv; +} + +std::vector CallGraphNode::GetEntryGlobals() const { + std::vector ret; + // An entry function in Relay is a function that never called by other + // functions or only called by itself. + for (const auto& it : *this) { + if (it.second->GetRefCount() == 0 || it.second->IsRecursiveEntry()) { + ret.push_back(it.second.get()); + } + } + return ret; +} + +std::vector CallGraphNode::TopologicalOrder() const { + std::vector ret; + // Collect all entry nodes. + std::vector entries = GetEntryGlobals(); + CallGraphEntryNode::CallGraphEntryNodeSet visited; + + for (const auto& it : entries) { + // Keep tracking the nodes that have been visited. + auto topo = it->TopologicalOrder(&visited); + // Preprend the collected items. The intermeidate nodes that are shared by + // multiple entries are guaranteed to be collected when visiting the + // previous entries. Therefore, topological order remains. + ret.insert(ret.begin(), topo.begin(), topo.end()); + } + + // Find out the missing global functions if there are any to help debugging. + if (ret.size() != module->functions.size()) { + for (auto it : module->functions) { + if (visited.find((*this)[it.first]) == visited.end()) { + LOG(WARNING) << "Missing global:" << it.first->name_hint + << " with # refs = " << (*this)[it.first]->GetRefCount(); + } + } + LOG(FATAL) << "Expected " << module->functions.size() + << " globals, but received " + << ret.size(); + } + + return ret; +} + +// A BSF traverser is used to collect the nodes in a CallGraphEntryNode. The nodes +// that are visited by previous CallGraphEntryNode entries can be memoized. This +// helps us to make sure no entry will be visited multiple times when collecting +// the nodes for an entir call graph. +std::vector CallGraphEntryNode::TopologicalOrder( + CallGraphEntryNodeSet* visited) const { + std::vector ret; + std::vector current_nodes; + if (visited->find(this) == visited->end()) { + visited->emplace(this); + current_nodes.emplace_back(const_cast(this)); + } + + std::vector next_nodes; + while (!current_nodes.empty()) { + for (const auto& node : current_nodes) { + ret.push_back(node); + // Iterate through the called entries. + for (auto git = node->begin(); git != node->end(); ++git) { + if (visited->find(git->second) == visited->end()) { + next_nodes.push_back(git->second); + visited->emplace(git->second); + } + } + } + // Update the current level and clean the next level. + current_nodes = next_nodes; + next_nodes.clear(); + } + return ret; +} + +void CallGraphEntryNode::CleanCallGraphEntries() { + while (!called_globals_.empty()) { + // Decrement the reference counter + called_globals_.back().second->DecRef(); + called_globals_.pop_back(); + } +} + +inline void CallGraphEntryNode::AddCalledGlobal(CallGraphEntryNode* cg_node) { + called_globals_.emplace_back(global_, cg_node); + // Increment the reference to indicate that another call site is found for + // the callee in `cg_node`. + cg_node->IncRef(); + // Mark the global function as recursive if it calls itself. + if (global_ == cg_node->GetGlobalVar()) { + cg_node->is_recursive_ = true; + } +} + +// Remove an edge from the current global function to the callee. +void CallGraphEntryNode::RemoveCallTo(const GlobalVar& callee) { + for (auto it = begin();; ++it) { + CHECK(it != end()) << "Cannot find global function " + << callee->name_hint << " to remove!"; + if (it->second->GetGlobalVar() == callee) { + // Only remove one occurrence of the call site. + it->second->DecRef(); + *it = called_globals_.back(); + called_globals_.pop_back(); + return; + } + } +} + +// Remove all edges from the current global function to the callee. +void CallGraphEntryNode::RemoveAllCallTo(CallGraphEntryNode* callee) { + for (uint32_t i = 0, e = size(); i != e;) { + if (called_globals_[i].second == callee) { + callee->DecRef(); + called_globals_[i] = called_globals_.back(); + called_globals_.pop_back(); + --e; + } else { + ++i; + } + } + // Make sure all references to the callee are removed. + CHECK_EQ(callee->GetRefCount(), 0U) + << "All references to " << callee->GetNameHint() + << " should have been removed"; +} + +void CallGraphEntryNode::Print(std::ostream& os) const { + if (!global_.defined()) { + os << "GlobalVar is not defined\n"; + return; + } + + os << "Call graph node: " << global_->name_hint; + os << " at: " << this << ", #refs = " << GetRefCount() << "\n"; + + for (const auto& it : *this) { + os << " call site: <" << it.first->name_hint << "> calls "; + os << it.second->GetNameHint() << "\n"; + } + os << "\n"; +} + +std::ostream& operator<<(std::ostream& os, const CallGraph& cg) { + cg->Print(os); + return os; +} + +std::ostream& operator<<(std::ostream& os, const CallGraphEntryNode& cgn) { + cgn.Print(os); + return os; +} + +TVM_REGISTER_NODE_TYPE(CallGraphNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + CHECK(node); + p->stream << "CallGraph: \n" << GetRef(node); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.CallGraph") +.set_body_typed([](IRModule module) { + return CallGraph(module); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraph") +.set_body_typed([](CallGraph call_graph) { + std::stringstream ss; + ss << call_graph; + return ss.str(); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.GetModule") +.set_body_typed([](CallGraph call_graph) { + return call_graph->GetModule(); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.GetCallGraphGlobalVar") +.set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + std::stringstream ss; + ss << *entry_node; + return ss.str(); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.GetRefCountGlobalVar") +.set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return static_cast(entry_node->GetRefCount()); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.GetGlobalVarCallCount") +.set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return static_cast(entry_node->size()); +}); + +TVM_REGISTER_GLOBAL("relay._analysis.IsRecursive") +.set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return entry_node->IsRecursive(); +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/call_graph.h b/src/relay/pass/call_graph.h new file mode 100644 index 000000000000..7e1f23db4b80 --- /dev/null +++ b/src/relay/pass/call_graph.h @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/pass/call_graph.h + * \brief Define data structures for the call graph of a IRModule. It borrows + * the idea how LLVM constructs CallGraph. + * + * https://llvm.org/doxygen/CallGraph_8h_source.html + */ + +#ifndef TVM_RELAY_PASS_CALL_GRAPH_H_ +#define TVM_RELAY_PASS_CALL_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +class CallGraphEntryNode; +class CallGraph; + +class CallGraphNode : public Object { + using CallGraphMap = + std::unordered_map, ObjectHash, + ObjectEqual>; + // Create iterator alias for a CallGraphNode object. + using iterator = CallGraphMap::iterator; + using const_iterator = CallGraphMap::const_iterator; + + public: + /*! \brief The IR module for creating a CallGraphNode. */ + IRModule module; + + /*! \brief Default constructor. */ + CallGraphNode() {} + + void VisitAttrs(AttrVisitor* v) { + v->Visit("module", &module); + } + + /*! + * \brief Print the call graph. + * + * \param os The stream for printing. + */ + void Print(std::ostream& os) const; + + /*! \return The begin iterator. */ + iterator begin() { + return call_graph_.begin(); + } + /*! \return The end iterator. */ + iterator end() { + return call_graph_.end(); + } + /*! \return The begin iterator. */ + const_iterator begin() const { + return call_graph_.begin(); + } + /*! \return The end iterator. */ + const_iterator end() const { + return call_graph_.end(); + } + + /*! + * \brief Get an element from the CallGraphNode using a GlobalVar. + * + * \param gv The GlobalVar used for indexing. + * + * \return The fetched element. + */ + const CallGraphEntryNode* operator[](const GlobalVar& gv) const; + /*! + * \brief Get an element from the CallGraphNode using a GlobalVar. + * + * \param gv The GlobalVar used for indexing. + * + * \return The fetched element. + */ + CallGraphEntryNode* operator[](const GlobalVar& gv); + /*! + * \brief Get an element from the CallGraphNode using the global function name. + * + * \param gvar_name The global function name used for indexing. + * + * \return The fetched element. + */ + const CallGraphEntryNode* operator[](const std::string& gvar_name) const { + return (*this)[module->GetGlobalVar(gvar_name)]; + } + /*! + * \brief Get an element from the CallGraphNode using the global function name. + * + * \param gvar_name The global function name used for indexing. + * + * \return The fetched element. + */ + CallGraphEntryNode* operator[](const std::string& gvar_name) { + return (*this)[module->GetGlobalVar(gvar_name)]; + } + + /*! \brief Return the IR module. */ + IRModule GetModule() const { + return module; + } + + /*! + * \brief Get the entries/root nodes of CallGraphNode. + * + * Entry functions are never referenced by other functions. + * Note these functions can be recursive as well. + * + * \return The list of CallGraphEntryNode that represent entry nodes. + */ + std::vector GetEntryGlobals() const; + + /*! + * \brief Remove a GlobalVar in a given CallGraphEntryNode from the current + * IR module. + * + * \param cg_node The CallGraphEntryNode that contains a global function to be + * removed. + * \param update_call_graph Indicate if we will update the CallGraph as well + * since updating is costly. We are only able to remove a leaf function + * when update_call_graph is disabled because the edges pointing to + * functions being removed are not updated. + * + * \return The GlobalVar removed from the current module. + */ + GlobalVar RemoveGlobalVarFromModule(CallGraphEntryNode* cg_node, + bool update_call_graph = false); + + /*! + * \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for + * the GlobalVar if it doesn't exist. + * + * \param gv The GlobalVar for query. + * + * \return The queried entry. + */ + CallGraphEntryNode* LookupGlobalVar(const GlobalVar& gv); + + /*! + * \brief Get the entries from the CallGraphNode in the topological order. + * + * This is useful for various module-level optimizations/analysis. For example, + * inlining requires the correct order of the functions being processed, i.e. + * callee should be always handled before callers. + * + * \return The list of collected entries that are sorted in the topological order. + */ + std::vector TopologicalOrder() const; + + static constexpr const char* _type_key = "relay.CallGraph"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallGraphNode, Object); + + private: + /*! + * \brief Create a CallGraphEntryNode for a global function and add it to the + * CallGraphNode. + * + * \param gv The global var. + * \param func The global function corresponding to `gv`. + */ + void AddToCallGraph(const GlobalVar& gv, const Function& func); + + /*! \brief A record contains GlobalVar to CallGraphEntryNode mapping. */ + CallGraphMap call_graph_; + + friend CallGraph; +}; + +/*! + * \brief The class that represents the call graph of a Relay IR module. It also + * provides a variety of utility functions for users to query, view, and update + * a call graph. + */ +class CallGraph : public ObjectRef { + using CallGraphMap = + std::unordered_map, ObjectHash, + ObjectEqual>; + // Create iterator alias for a CallGraph object. + using iterator = CallGraphMap::iterator; + using const_iterator = CallGraphMap::const_iterator; + + public: + /*! + * \brief Construct a CallGraph from a IR module. + * + * \param module The IR module + */ + explicit CallGraph(IRModule module); + + /*! + * \brief Construct from an object pointer. + * \param n The object pointer. + */ + explicit CallGraph(ObjectPtr n) : ObjectRef(n) {} + + /*! \return The begin iterator. */ + iterator begin() { + auto* n = operator->(); + CHECK(n); + return n->begin(); + } + /*! \return The end iterator. */ + iterator end() { + auto* n = operator->(); + CHECK(n); + return n->end(); + } + /*! \return The begin iterator. */ + const_iterator begin() const { + const auto* n = operator->(); + CHECK(n); + return n->begin(); + } + /*! \return The end iterator. */ + const_iterator end() const { + const auto* n = operator->(); + CHECK(n); + return n->end(); + } + + /*! + * \brief Get an element from the CallGraph using a GlobalVar. + * + * \param gv The GlobalVar used for indexing. + * + * \return The fetched element. + */ + const CallGraphEntryNode* operator[](const GlobalVar& gv) const { + const auto* n = operator->(); + CHECK(n); + return (*n)[gv]; + } + /*! + * \brief Get an element from the CallGraph using a GlobalVar. + * + * \param gv The GlobalVar used for indexing. + * + * \return The fetched element. + */ + CallGraphEntryNode* operator[](const GlobalVar& gv) { + auto* n = operator->(); + CHECK(n); + return (*n)[gv]; + } + /*! + * \brief Get an element from the CallGraph using the global function name. + * + * \param gvar_name The global function name used for indexing. + * + * \return The fetched element. + */ + const CallGraphEntryNode* operator[](const std::string& gvar_name) const { + const auto* n = operator->(); + CHECK(n); + return (*n)[gvar_name]; + } + /*! + * \brief Get an element from the CallGraph using the global function name. + * + * \param gvar_name The global function name used for indexing. + * + * \return The fetched element. + */ + CallGraphEntryNode* operator[](const std::string& gvar_name) { + auto* n = operator->(); + CHECK(n); + return (*n)[gvar_name]; + } + + /*! \return mutable pointers to the node. */ + CallGraphNode* operator->() const { + auto* ptr = get_mutable(); + CHECK(ptr != nullptr); + return static_cast(ptr); + } + + private: + /*! \brief Overload the << operator to print a call graph. */ + friend std::ostream& operator<<(std::ostream& os, const CallGraph&); +}; + +/*! + * \brief A node in the call graph. It maintains the edges from a caller to + * all callees. + */ +class CallGraphEntryNode { + public: + using CallGraphEntry = std::pair; + using CallGraphEntryVector = std::vector; + using CallGraphEntryNodeSet = std::unordered_set; + // Create iterator alias for a CallGraphEntryNode object. + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + + /*! + * \brief Construct from a GlobalVar. + * + * \param gv The GlobalVar to create a CallGraphEntryNode. + */ + explicit CallGraphEntryNode(const GlobalVar& gv) : global_(gv) {} + /*! + * \brief Delete copy constructor. + */ + CallGraphEntryNode(const CallGraphEntryNode&) = delete; + /*! \brief Delete assignment. */ + CallGraphEntryNode& operator=(const CallGraphEntryNode&) = delete; + + /*! \return The begin iterator */ + iterator begin() { + return called_globals_.begin(); + } + /*! \return The end iterator */ + iterator end() { + return called_globals_.end(); + } + /*! \return The const begin iterator */ + const_iterator begin() const { + return called_globals_.begin(); + } + /*! \return The const end iterator */ + const_iterator end() const { + return called_globals_.end(); + } + + /*! + * \brief Return if the list of called nodes is empty. + * + * \return true if the list is empty. Otherwise, false. + */ + bool empty() const { + return called_globals_.empty(); + } + + /*! + * \brief Return the size of the list that represents the nodes are called by + * the current node. + * + * \return The number of called nodes. + */ + uint32_t size() const { + return static_cast(called_globals_.size()); + } + + /*! + * \brief Fetch the i-th CallGraphEntryNode from the list of nodes that are called + * by the current function. + * + * \param i The index. + * + * \return The fetched CallGraphEntryNode. + */ + CallGraphEntryNode* operator[](size_t i) const { + CHECK_LT(i, called_globals_.size()) << "Invalid Index"; + return called_globals_[i].second; + } + + /*! + * \brief Print the call graph that is stemmed from the current CallGraphEntryNode. + * + * \param os The stream for printing. + */ + void Print(std::ostream& os) const; + + /*! + * \brief Return the number of times the global function is referenced. + * + * \return The count. + */ + uint32_t GetRefCount() const { + return ref_cnt_; + } + + /*! + * \brief Return the GlobalVar stored in the current CallGraphEntryNode. + * + * \return The GlobalVar. + */ + GlobalVar GetGlobalVar() const { + return global_; + } + + /*! + * \brief Return the name hint of the GlobalVar stored in the CallGraphEntryNode. + * + * \return The name hint of the global function. + */ + std::string GetNameHint() const { + return global_->name_hint; + } + + /*! + * \brief Return if the global function corresponding to the current + * CallGraphEntryNode is a recursive function. + * + * \return true if it is recursive. Otherwise, false. + */ + bool IsRecursive() const { + return is_recursive_; + } + + /*! + * \brief Return if the global function corresponding to the current + * CallGraphEntryNode is both a recursive function and an entry function. This type + * of function only has one reference which is called by itself. + * + * \return true if it is both a recursive function and an entry. Otherwise, false. + */ + bool IsRecursiveEntry() const { + return GetRefCount() == 1 && IsRecursive(); + } + + /*! + * \brief Return the topological order of the CallGraphEntryNode. + * + * \param visited A set of CallGraphEntryNode objects that have been visited. + * + * \return The list of CallGraphEntryNode that is represented in topological order. + */ + std::vector TopologicalOrder( + CallGraphEntryNodeSet* visited = new CallGraphEntryNodeSet()) const; + + /*! + * \brief Remove all edges from the current CallGraphEntryNode to any global + * function it calls. + */ + void CleanCallGraphEntries(); + + /*! + * \brief Add a node to the list of nodes that are being called by the current + * global function. + * + * \param cg_node The CallGraphEntryNode that will be added to the call list. + */ + void AddCalledGlobal(CallGraphEntryNode* cg_node); + + /*! + * \brief Remove a call edge to the global function from the current + * function. + * + * \param callee The function that is being called. + */ + void RemoveCallTo(const GlobalVar& callee); + + /*! + * \brief Remove all the edges that represent that calls to the global function + * stored in a given CallGraphEntryNode. + * + * \param callee The function that is being called. + */ + void RemoveAllCallTo(CallGraphEntryNode* callee); + + private: + /*! \brief Decrement the reference counter by 1. */ + void DecRef() { + CHECK_GT(ref_cnt_, 0); + --ref_cnt_; + } + /*! \brief Increment the reference counter by 1. */ + void IncRef() { ++ref_cnt_; } + + /*! + * \brief Mark if the global function stored in the CallGraphEntryNode is + * recursive function. + */ + bool is_recursive_{false}; + /*! \brief Count the number of times the global function is referenced. */ + uint32_t ref_cnt_{0}; + /*! \brief The GlobalVar stored in the current CallGraphEntryNode. */ + GlobalVar global_; + /*! \brief The list of entries called by the current CallGraphEntryNode. */ + CallGraphEntryVector called_globals_; + + friend class CallGraph; + /*! \brief Overload the << operator to print a call graph node. */ + friend std::ostream& operator<<(std::ostream& os, const CallGraphEntryNode&); +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_CALL_GRAPH_H_ diff --git a/tests/python/relay/test_call_graph.py b/tests/python/relay/test_call_graph.py new file mode 100644 index 000000000000..4d82c5c2ce22 --- /dev/null +++ b/tests/python/relay/test_call_graph.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +import pytest +import tvm +from tvm import relay + + +def test_callgraph_construct(): + mod = tvm.IRModule({}) + x = relay.var("x", shape=(2, 3)) + y = relay.var("y", shape=(2, 3)) + mod["g1"] = relay.Function([x, y], x + y) + call_graph = relay.CallGraph(mod) + assert "g1" in str(call_graph) + assert relay.alpha_equal(mod, call_graph.module) + + +def test_print_element(): + mod = tvm.IRModule({}) + x0 = relay.var("x0", shape=(2, 3)) + y0 = relay.var("y0", shape=(2, 3)) + mod["g0"] = relay.Function([x0, y0], x0 + y0) + x1 = relay.var("x1", shape=(2, 3)) + y1 = relay.var("y1", shape=(2, 3)) + mod["g1"] = relay.Function([x1, y1], x1 - y1) + call_graph = relay.CallGraph(mod) + + assert "#refs = 0" in str(call_graph["g0"]) + assert "#refs = 0" in str(call_graph["g1"]) + + +def test_global_call_count(): + mod = tvm.IRModule({}) + x0 = relay.var("x0", shape=(2, 3)) + y0 = relay.var("y0", shape=(2, 3)) + g0 = relay.GlobalVar("g0") + mod[g0] = relay.Function([x0, y0], x0 + y0) + x1 = relay.var("x1", shape=(2, 3)) + y1 = relay.var("y1", shape=(2, 3)) + g1 = relay.GlobalVar("g1") + mod[g1] = relay.Function([x1, y1], g0(x1, y1)) + call_graph = relay.CallGraph(mod) + + p0 = relay.var("p0", shape=(2, 3)) + p1 = relay.var("p1", shape=(2, 3)) + func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) + mod["main"] = func + call_graph = relay.CallGraph(mod) + + assert call_graph.global_call_count(g0) == 0 + assert call_graph.global_call_count(g1) == 1 + assert call_graph.global_call_count("main") == 2 + + +def test_ref_count(): + mod = tvm.IRModule({}) + x0 = relay.var("x0", shape=(2, 3)) + y0 = relay.var("y0", shape=(2, 3)) + g0 = relay.GlobalVar("g0") + mod[g0] = relay.Function([x0, y0], x0 + y0) + x1 = relay.var("x1", shape=(2, 3)) + y1 = relay.var("y1", shape=(2, 3)) + g1 = relay.GlobalVar("g1") + mod[g1] = relay.Function([x1, y1], x1 - y1) + call_graph = relay.CallGraph(mod) + + p0 = relay.var("p0", shape=(2, 3)) + p1 = relay.var("p1", shape=(2, 3)) + func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) + mod["main"] = func + call_graph = relay.CallGraph(mod) + + assert call_graph.ref_count(g0) == 1 + assert call_graph.ref_count(g1) == 1 + assert call_graph.ref_count("main") == 0 + + +def test_nested_ref(): + mod = tvm.IRModule({}) + x0 = relay.var("x0", shape=(2, 3)) + y0 = relay.var("y0", shape=(2, 3)) + g0 = relay.GlobalVar("g0") + mod[g0] = relay.Function([x0, y0], x0 + y0) + x1 = relay.var("x1", shape=(2, 3)) + y1 = relay.var("y1", shape=(2, 3)) + g1 = relay.GlobalVar("g1") + mod[g1] = relay.Function([x1, y1], g0(x1, y1)) + call_graph = relay.CallGraph(mod) + + p0 = relay.var("p0", shape=(2, 3)) + p1 = relay.var("p1", shape=(2, 3)) + func = relay.Function([p0, p1], g0(p0, p1) * g1(p0, p1)) + mod["main"] = func + call_graph = relay.CallGraph(mod) + + assert call_graph.ref_count(g0) == 2 + assert call_graph.ref_count(g1) == 1 + assert call_graph.ref_count("main") == 0 + + +def test_recursive_func(): + mod = tvm.IRModule({}) + + x = relay.var('x', shape=[], dtype='int32') + fn0 = relay.Function([x], x) + gx = relay.GlobalVar("gx") + mod[gx] = fn0 + + sum_up = relay.GlobalVar('sum_up') + i = relay.var('i', shape=[], dtype='int32') + sb = relay.ScopeBuilder() + with sb.if_scope(relay.equal(i, relay.const(0, dtype='int32'))): + sb.ret(i) + with sb.else_scope(): + one_less = relay.subtract(i, relay.const(1, dtype='int32')) + global_call = gx(i) + rec_call = relay.Call(sum_up, [one_less]) + global_call + sb.ret(relay.add(rec_call, i)) + func = relay.Function([i], + sb.get(), + ret_type=relay.TensorType([], 'int32')) + func = func.set_attribute("Compiler", tvm.tir.StringImm("a")) + mod[sum_up] = func + iarg = relay.var('i', shape=[], dtype='int32') + mod["main"] = relay.Function([iarg], sum_up(iarg)) + call_graph = relay.CallGraph(mod) + + assert call_graph.is_recursive(sum_up) + assert call_graph.ref_count(sum_up) == 2 + assert call_graph.ref_count(gx) == 1 + assert call_graph.ref_count("main") == 0 + + +if __name__ == "__main__": + pytest.main() From a7f0d6bee33d8ccbe8301cf139735e29ca795f36 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 25 Feb 2020 04:43:34 +0000 Subject: [PATCH 2/3] CallGraphEntryNode->CallGraphEntry, __getitem__->print_var --- python/tvm/relay/call_graph.py | 17 +++-- src/relay/pass/call_graph.cc | 56 +++++++------- src/relay/pass/call_graph.h | 106 +++++++++++++------------- tests/python/relay/test_call_graph.py | 4 +- 4 files changed, 92 insertions(+), 91 deletions(-) diff --git a/python/tvm/relay/call_graph.py b/python/tvm/relay/call_graph.py index 104ccda8d585..8206f5dccd4c 100644 --- a/python/tvm/relay/call_graph.py +++ b/python/tvm/relay/call_graph.py @@ -87,7 +87,8 @@ def global_call_count(self, var): return _analysis.GetGlobalVarCallCount(self, var) def is_recursive(self, var): - """Return the number of global function calls from a given global var. + """Return if the function corresponding to a var is a recursive + function. Parameters ---------- @@ -122,12 +123,8 @@ def _get_global_var(self, var): else: raise TypeError("var should be either a string or GlobalVar") - def __str__(self): - """Print the call graph in the topological order.""" - return _analysis.PrintCallGraph(self) - - def __getitem__(self, var): - """Lookup a call graph of a global function by name or by variable. + def print_var(self, var): + """Print a call graph of a global function by name or by variable. Parameters ---------- @@ -140,4 +137,8 @@ def __getitem__(self, var): The call graph represented in string. """ var = self._get_global_var(var) - return _analysis.GetCallGraphGlobalVar(self, var) + return _analysis.PrintCallGraphGlobalVar(self, var) + + def __str__(self): + """Print the call graph in the topological order.""" + return _analysis.PrintCallGraph(self) diff --git a/src/relay/pass/call_graph.cc b/src/relay/pass/call_graph.cc index 5a4b6a91c04a..42fdf20c726c 100644 --- a/src/relay/pass/call_graph.cc +++ b/src/relay/pass/call_graph.cc @@ -52,7 +52,7 @@ CallGraph::CallGraph(IRModule module) { void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) { CHECK(func.defined() && gv.defined()); // Add the current global function as an entry to the call grpah. - CallGraphEntryNode* cg_node = LookupGlobalVar(gv); + CallGraphEntry* cg_node = LookupGlobalVar(gv); // Only GlobalVar nodes need to be handled in a function. It indicates that // the global function of a callee is called by the function that is being @@ -70,14 +70,14 @@ void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) { }); } -const CallGraphEntryNode* CallGraphNode::operator[](const GlobalVar& gv) const { +const CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) const { const_iterator cit = call_graph_.find(gv); CHECK(cit != call_graph_.end()) << "GlobalVar " << gv->name_hint << " not found in the call graph!"; return cit->second.get(); } -CallGraphEntryNode* CallGraphNode::operator[](const GlobalVar& gv) { +CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) { const_iterator cit = call_graph_.find(gv); CHECK(cit != call_graph_.end()) << "GlobalVar " << gv->name_hint << " not found in the call graph!"; @@ -86,7 +86,7 @@ CallGraphEntryNode* CallGraphNode::operator[](const GlobalVar& gv) { // Query the existence of a GlobalVar in the call graph. It creates an entry if // there is no such a node available. -CallGraphEntryNode* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) { +CallGraphEntry* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) { CHECK(gv.defined()); // This inserts an element to the call graph if it is not there yet. @@ -97,19 +97,19 @@ CallGraphEntryNode* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) { << "GlobalVar " << gv->name_hint << " not found in the current ir module"; // Create the node for the inserted entry. - call_graph_node = std::unique_ptr(new CallGraphEntryNode(gv)); + call_graph_node = std::unique_ptr(new CallGraphEntry(gv)); return call_graph_node.get(); } void CallGraphNode::Print(std::ostream& os) const { // Print the call graph in the topological order. - std::vector nodes = TopologicalOrder(); + std::vector nodes = TopologicalOrder(); for (const auto* cgn : nodes) { cgn->Print(os); } } -GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntryNode* cg_node, +GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph) { CHECK(cg_node->empty() || (cg_node->IsRecursive() && cg_node->size() == 1)) << "Cannot remove global var " << cg_node->GetNameHint() @@ -130,8 +130,8 @@ GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntryNode* cg_node, return gv; } -std::vector CallGraphNode::GetEntryGlobals() const { - std::vector ret; +std::vector CallGraphNode::GetEntryGlobals() const { + std::vector ret; // An entry function in Relay is a function that never called by other // functions or only called by itself. for (const auto& it : *this) { @@ -142,11 +142,11 @@ std::vector CallGraphNode::GetEntryGlobals() const { return ret; } -std::vector CallGraphNode::TopologicalOrder() const { - std::vector ret; +std::vector CallGraphNode::TopologicalOrder() const { + std::vector ret; // Collect all entry nodes. - std::vector entries = GetEntryGlobals(); - CallGraphEntryNode::CallGraphEntryNodeSet visited; + std::vector entries = GetEntryGlobals(); + CallGraphEntry::CallGraphEntrySet visited; for (const auto& it : entries) { // Keep tracking the nodes that have been visited. @@ -173,20 +173,20 @@ std::vector CallGraphNode::TopologicalOrder() const { return ret; } -// A BSF traverser is used to collect the nodes in a CallGraphEntryNode. The nodes -// that are visited by previous CallGraphEntryNode entries can be memoized. This +// A BSF traverser is used to collect the nodes in a CallGraphEntry. The nodes +// that are visited by previous CallGraphEntry entries can be memoized. This // helps us to make sure no entry will be visited multiple times when collecting // the nodes for an entir call graph. -std::vector CallGraphEntryNode::TopologicalOrder( - CallGraphEntryNodeSet* visited) const { - std::vector ret; - std::vector current_nodes; +std::vector CallGraphEntry::TopologicalOrder( + CallGraphEntrySet* visited) const { + std::vector ret; + std::vector current_nodes; if (visited->find(this) == visited->end()) { visited->emplace(this); - current_nodes.emplace_back(const_cast(this)); + current_nodes.emplace_back(const_cast(this)); } - std::vector next_nodes; + std::vector next_nodes; while (!current_nodes.empty()) { for (const auto& node : current_nodes) { ret.push_back(node); @@ -205,7 +205,7 @@ std::vector CallGraphEntryNode::TopologicalOrder( return ret; } -void CallGraphEntryNode::CleanCallGraphEntries() { +void CallGraphEntry::CleanCallGraphEntries() { while (!called_globals_.empty()) { // Decrement the reference counter called_globals_.back().second->DecRef(); @@ -213,7 +213,7 @@ void CallGraphEntryNode::CleanCallGraphEntries() { } } -inline void CallGraphEntryNode::AddCalledGlobal(CallGraphEntryNode* cg_node) { +inline void CallGraphEntry::AddCalledGlobal(CallGraphEntry* cg_node) { called_globals_.emplace_back(global_, cg_node); // Increment the reference to indicate that another call site is found for // the callee in `cg_node`. @@ -225,7 +225,7 @@ inline void CallGraphEntryNode::AddCalledGlobal(CallGraphEntryNode* cg_node) { } // Remove an edge from the current global function to the callee. -void CallGraphEntryNode::RemoveCallTo(const GlobalVar& callee) { +void CallGraphEntry::RemoveCallTo(const GlobalVar& callee) { for (auto it = begin();; ++it) { CHECK(it != end()) << "Cannot find global function " << callee->name_hint << " to remove!"; @@ -240,7 +240,7 @@ void CallGraphEntryNode::RemoveCallTo(const GlobalVar& callee) { } // Remove all edges from the current global function to the callee. -void CallGraphEntryNode::RemoveAllCallTo(CallGraphEntryNode* callee) { +void CallGraphEntry::RemoveAllCallTo(CallGraphEntry* callee) { for (uint32_t i = 0, e = size(); i != e;) { if (called_globals_[i].second == callee) { callee->DecRef(); @@ -257,7 +257,7 @@ void CallGraphEntryNode::RemoveAllCallTo(CallGraphEntryNode* callee) { << " should have been removed"; } -void CallGraphEntryNode::Print(std::ostream& os) const { +void CallGraphEntry::Print(std::ostream& os) const { if (!global_.defined()) { os << "GlobalVar is not defined\n"; return; @@ -278,7 +278,7 @@ std::ostream& operator<<(std::ostream& os, const CallGraph& cg) { return os; } -std::ostream& operator<<(std::ostream& os, const CallGraphEntryNode& cgn) { +std::ostream& operator<<(std::ostream& os, const CallGraphEntry& cgn) { cgn.Print(os); return os; } @@ -309,7 +309,7 @@ TVM_REGISTER_GLOBAL("relay._analysis.GetModule") return call_graph->GetModule(); }); -TVM_REGISTER_GLOBAL("relay._analysis.GetCallGraphGlobalVar") +TVM_REGISTER_GLOBAL("relay._analysis.PrintCallGraphGlobalVar") .set_body_typed([](CallGraph call_graph, GlobalVar var) { const auto* entry_node = call_graph[var]; std::stringstream ss; diff --git a/src/relay/pass/call_graph.h b/src/relay/pass/call_graph.h index 7e1f23db4b80..340ee30bc5d2 100644 --- a/src/relay/pass/call_graph.h +++ b/src/relay/pass/call_graph.h @@ -41,12 +41,12 @@ namespace tvm { namespace relay { -class CallGraphEntryNode; +class CallGraphEntry; class CallGraph; class CallGraphNode : public Object { using CallGraphMap = - std::unordered_map, ObjectHash, + std::unordered_map, ObjectHash, ObjectEqual>; // Create iterator alias for a CallGraphNode object. using iterator = CallGraphMap::iterator; @@ -94,7 +94,7 @@ class CallGraphNode : public Object { * * \return The fetched element. */ - const CallGraphEntryNode* operator[](const GlobalVar& gv) const; + const CallGraphEntry* operator[](const GlobalVar& gv) const; /*! * \brief Get an element from the CallGraphNode using a GlobalVar. * @@ -102,7 +102,7 @@ class CallGraphNode : public Object { * * \return The fetched element. */ - CallGraphEntryNode* operator[](const GlobalVar& gv); + CallGraphEntry* operator[](const GlobalVar& gv); /*! * \brief Get an element from the CallGraphNode using the global function name. * @@ -110,7 +110,7 @@ class CallGraphNode : public Object { * * \return The fetched element. */ - const CallGraphEntryNode* operator[](const std::string& gvar_name) const { + const CallGraphEntry* operator[](const std::string& gvar_name) const { return (*this)[module->GetGlobalVar(gvar_name)]; } /*! @@ -120,7 +120,7 @@ class CallGraphNode : public Object { * * \return The fetched element. */ - CallGraphEntryNode* operator[](const std::string& gvar_name) { + CallGraphEntry* operator[](const std::string& gvar_name) { return (*this)[module->GetGlobalVar(gvar_name)]; } @@ -135,15 +135,15 @@ class CallGraphNode : public Object { * Entry functions are never referenced by other functions. * Note these functions can be recursive as well. * - * \return The list of CallGraphEntryNode that represent entry nodes. + * \return The list of CallGraphEntry that represent entry nodes. */ - std::vector GetEntryGlobals() const; + std::vector GetEntryGlobals() const; /*! - * \brief Remove a GlobalVar in a given CallGraphEntryNode from the current + * \brief Remove a GlobalVar in a given CallGraphEntry from the current * IR module. * - * \param cg_node The CallGraphEntryNode that contains a global function to be + * \param cg_node The CallGraphEntry that contains a global function to be * removed. * \param update_call_graph Indicate if we will update the CallGraph as well * since updating is costly. We are only able to remove a leaf function @@ -152,7 +152,7 @@ class CallGraphNode : public Object { * * \return The GlobalVar removed from the current module. */ - GlobalVar RemoveGlobalVarFromModule(CallGraphEntryNode* cg_node, + GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph = false); /*! @@ -163,7 +163,7 @@ class CallGraphNode : public Object { * * \return The queried entry. */ - CallGraphEntryNode* LookupGlobalVar(const GlobalVar& gv); + CallGraphEntry* LookupGlobalVar(const GlobalVar& gv); /*! * \brief Get the entries from the CallGraphNode in the topological order. @@ -174,14 +174,14 @@ class CallGraphNode : public Object { * * \return The list of collected entries that are sorted in the topological order. */ - std::vector TopologicalOrder() const; + std::vector TopologicalOrder() const; static constexpr const char* _type_key = "relay.CallGraph"; TVM_DECLARE_FINAL_OBJECT_INFO(CallGraphNode, Object); private: /*! - * \brief Create a CallGraphEntryNode for a global function and add it to the + * \brief Create a CallGraphEntry for a global function and add it to the * CallGraphNode. * * \param gv The global var. @@ -189,7 +189,7 @@ class CallGraphNode : public Object { */ void AddToCallGraph(const GlobalVar& gv, const Function& func); - /*! \brief A record contains GlobalVar to CallGraphEntryNode mapping. */ + /*! \brief A record contains GlobalVar to CallGraphEntry mapping. */ CallGraphMap call_graph_; friend CallGraph; @@ -202,7 +202,7 @@ class CallGraphNode : public Object { */ class CallGraph : public ObjectRef { using CallGraphMap = - std::unordered_map, ObjectHash, + std::unordered_map, ObjectHash, ObjectEqual>; // Create iterator alias for a CallGraph object. using iterator = CallGraphMap::iterator; @@ -254,7 +254,7 @@ class CallGraph : public ObjectRef { * * \return The fetched element. */ - const CallGraphEntryNode* operator[](const GlobalVar& gv) const { + const CallGraphEntry* operator[](const GlobalVar& gv) const { const auto* n = operator->(); CHECK(n); return (*n)[gv]; @@ -266,7 +266,7 @@ class CallGraph : public ObjectRef { * * \return The fetched element. */ - CallGraphEntryNode* operator[](const GlobalVar& gv) { + CallGraphEntry* operator[](const GlobalVar& gv) { auto* n = operator->(); CHECK(n); return (*n)[gv]; @@ -278,7 +278,7 @@ class CallGraph : public ObjectRef { * * \return The fetched element. */ - const CallGraphEntryNode* operator[](const std::string& gvar_name) const { + const CallGraphEntry* operator[](const std::string& gvar_name) const { const auto* n = operator->(); CHECK(n); return (*n)[gvar_name]; @@ -290,7 +290,7 @@ class CallGraph : public ObjectRef { * * \return The fetched element. */ - CallGraphEntryNode* operator[](const std::string& gvar_name) { + CallGraphEntry* operator[](const std::string& gvar_name) { auto* n = operator->(); CHECK(n); return (*n)[gvar_name]; @@ -312,27 +312,27 @@ class CallGraph : public ObjectRef { * \brief A node in the call graph. It maintains the edges from a caller to * all callees. */ -class CallGraphEntryNode { +class CallGraphEntry { public: - using CallGraphEntry = std::pair; - using CallGraphEntryVector = std::vector; - using CallGraphEntryNodeSet = std::unordered_set; - // Create iterator alias for a CallGraphEntryNode object. - using iterator = std::vector::iterator; - using const_iterator = std::vector::const_iterator; + using CallGraphEntryPair = std::pair; + using CallGraphEntryVector = std::vector; + using CallGraphEntrySet = std::unordered_set; + // Create iterator alias for a CallGraphEntry object. + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; /*! * \brief Construct from a GlobalVar. * - * \param gv The GlobalVar to create a CallGraphEntryNode. + * \param gv The GlobalVar to create a CallGraphEntry. */ - explicit CallGraphEntryNode(const GlobalVar& gv) : global_(gv) {} + explicit CallGraphEntry(const GlobalVar& gv) : global_(gv) {} /*! * \brief Delete copy constructor. */ - CallGraphEntryNode(const CallGraphEntryNode&) = delete; + CallGraphEntry(const CallGraphEntry&) = delete; /*! \brief Delete assignment. */ - CallGraphEntryNode& operator=(const CallGraphEntryNode&) = delete; + CallGraphEntry& operator=(const CallGraphEntry&) = delete; /*! \return The begin iterator */ iterator begin() { @@ -371,20 +371,20 @@ class CallGraphEntryNode { } /*! - * \brief Fetch the i-th CallGraphEntryNode from the list of nodes that are called + * \brief Fetch the i-th CallGraphEntry from the list of nodes that are called * by the current function. * * \param i The index. * - * \return The fetched CallGraphEntryNode. + * \return The fetched CallGraphEntry. */ - CallGraphEntryNode* operator[](size_t i) const { + CallGraphEntry* operator[](size_t i) const { CHECK_LT(i, called_globals_.size()) << "Invalid Index"; return called_globals_[i].second; } /*! - * \brief Print the call graph that is stemmed from the current CallGraphEntryNode. + * \brief Print the call graph that is stemmed from the current CallGraphEntry. * * \param os The stream for printing. */ @@ -400,7 +400,7 @@ class CallGraphEntryNode { } /*! - * \brief Return the GlobalVar stored in the current CallGraphEntryNode. + * \brief Return the GlobalVar stored in the current CallGraphEntry. * * \return The GlobalVar. */ @@ -409,7 +409,7 @@ class CallGraphEntryNode { } /*! - * \brief Return the name hint of the GlobalVar stored in the CallGraphEntryNode. + * \brief Return the name hint of the GlobalVar stored in the CallGraphEntry. * * \return The name hint of the global function. */ @@ -419,7 +419,7 @@ class CallGraphEntryNode { /*! * \brief Return if the global function corresponding to the current - * CallGraphEntryNode is a recursive function. + * CallGraphEntry is a recursive function. * * \return true if it is recursive. Otherwise, false. */ @@ -429,7 +429,7 @@ class CallGraphEntryNode { /*! * \brief Return if the global function corresponding to the current - * CallGraphEntryNode is both a recursive function and an entry function. This type + * CallGraphEntry is both a recursive function and an entry function. This type * of function only has one reference which is called by itself. * * \return true if it is both a recursive function and an entry. Otherwise, false. @@ -439,17 +439,17 @@ class CallGraphEntryNode { } /*! - * \brief Return the topological order of the CallGraphEntryNode. + * \brief Return the topological order of the CallGraphEntry. * - * \param visited A set of CallGraphEntryNode objects that have been visited. + * \param visited A set of CallGraphEntry objects that have been visited. * - * \return The list of CallGraphEntryNode that is represented in topological order. + * \return The list of CallGraphEntry that is represented in topological order. */ - std::vector TopologicalOrder( - CallGraphEntryNodeSet* visited = new CallGraphEntryNodeSet()) const; + std::vector TopologicalOrder( + CallGraphEntrySet* visited = new CallGraphEntrySet()) const; /*! - * \brief Remove all edges from the current CallGraphEntryNode to any global + * \brief Remove all edges from the current CallGraphEntry to any global * function it calls. */ void CleanCallGraphEntries(); @@ -458,9 +458,9 @@ class CallGraphEntryNode { * \brief Add a node to the list of nodes that are being called by the current * global function. * - * \param cg_node The CallGraphEntryNode that will be added to the call list. + * \param cg_node The CallGraphEntry that will be added to the call list. */ - void AddCalledGlobal(CallGraphEntryNode* cg_node); + void AddCalledGlobal(CallGraphEntry* cg_node); /*! * \brief Remove a call edge to the global function from the current @@ -472,11 +472,11 @@ class CallGraphEntryNode { /*! * \brief Remove all the edges that represent that calls to the global function - * stored in a given CallGraphEntryNode. + * stored in a given CallGraphEntry. * * \param callee The function that is being called. */ - void RemoveAllCallTo(CallGraphEntryNode* callee); + void RemoveAllCallTo(CallGraphEntry* callee); private: /*! \brief Decrement the reference counter by 1. */ @@ -488,20 +488,20 @@ class CallGraphEntryNode { void IncRef() { ++ref_cnt_; } /*! - * \brief Mark if the global function stored in the CallGraphEntryNode is + * \brief Mark if the global function stored in the CallGraphEntry is * recursive function. */ bool is_recursive_{false}; /*! \brief Count the number of times the global function is referenced. */ uint32_t ref_cnt_{0}; - /*! \brief The GlobalVar stored in the current CallGraphEntryNode. */ + /*! \brief The GlobalVar stored in the current CallGraphEntry. */ GlobalVar global_; - /*! \brief The list of entries called by the current CallGraphEntryNode. */ + /*! \brief The list of entries called by the current CallGraphEntry. */ CallGraphEntryVector called_globals_; friend class CallGraph; /*! \brief Overload the << operator to print a call graph node. */ - friend std::ostream& operator<<(std::ostream& os, const CallGraphEntryNode&); + friend std::ostream& operator<<(std::ostream& os, const CallGraphEntry&); }; } // namespace relay diff --git a/tests/python/relay/test_call_graph.py b/tests/python/relay/test_call_graph.py index 4d82c5c2ce22..fbbda678b102 100644 --- a/tests/python/relay/test_call_graph.py +++ b/tests/python/relay/test_call_graph.py @@ -40,8 +40,8 @@ def test_print_element(): mod["g1"] = relay.Function([x1, y1], x1 - y1) call_graph = relay.CallGraph(mod) - assert "#refs = 0" in str(call_graph["g0"]) - assert "#refs = 0" in str(call_graph["g1"]) + assert "#refs = 0" in str(call_graph.print_var("g0")) + assert "#refs = 0" in str(call_graph.print_var("g1")) def test_global_call_count(): From 4460950cbcaa457c54c912972a6f07a2d6a2693c Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 25 Feb 2020 23:36:14 +0000 Subject: [PATCH 3/3] fix typos --- src/relay/pass/call_graph.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relay/pass/call_graph.cc b/src/relay/pass/call_graph.cc index 42fdf20c726c..6b82801776dd 100644 --- a/src/relay/pass/call_graph.cc +++ b/src/relay/pass/call_graph.cc @@ -85,7 +85,7 @@ CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) { } // Query the existence of a GlobalVar in the call graph. It creates an entry if -// there is no such a node available. +// there is no such node available. CallGraphEntry* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) { CHECK(gv.defined()); @@ -151,7 +151,7 @@ std::vector CallGraphNode::TopologicalOrder() const { for (const auto& it : entries) { // Keep tracking the nodes that have been visited. auto topo = it->TopologicalOrder(&visited); - // Preprend the collected items. The intermeidate nodes that are shared by + // Prepend the collected items. The intermediate nodes that are shared by // multiple entries are guaranteed to be collected when visiting the // previous entries. Therefore, topological order remains. ret.insert(ret.begin(), topo.begin(), topo.end()); @@ -173,10 +173,10 @@ std::vector CallGraphNode::TopologicalOrder() const { return ret; } -// A BSF traverser is used to collect the nodes in a CallGraphEntry. The nodes +// BSF traversal is used to collect the nodes in a CallGraphEntry. The nodes // that are visited by previous CallGraphEntry entries can be memoized. This // helps us to make sure no entry will be visited multiple times when collecting -// the nodes for an entir call graph. +// the nodes for an entire call graph. std::vector CallGraphEntry::TopologicalOrder( CallGraphEntrySet* visited) const { std::vector ret;