diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 6e2209d51950..8abee12a5b98 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -457,6 +457,18 @@ struct VarUsageInfo { */ VarUsageInfo CollectVarUsage(const Expr& expr); +/*! + * \brief Perform a liveness analysis on the function, indicating which variables + * are live at which location in the function. + * + * \param fn The function to be analyzed. + * \return An array of arrays of live variables per binding in the function. + * The array is indexed based on the corresponding control flow graph, + * so use `ExtractCFG` and `GetBindingIndex` to match locations in `fn` + * to indices in the result. + */ +Array> LivenessAnalysis(const Function& fn); + /*! * \brief Remove unused statements inside DataflowBlocks. * diff --git a/include/tvm/relax/dataflow_analysis.h b/include/tvm/relax/dataflow_analysis.h new file mode 100644 index 000000000000..af5823a82b84 --- /dev/null +++ b/include/tvm/relax/dataflow_analysis.h @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. 
+ * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graph by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common) + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. The body expression in a SeqExpr (not actually bound) + */ +enum class BindingNodeKind { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; + +class GraphBindingNode : public Object { + public: + /*! \brief The SeqExpr the binding resides in. */ + SeqExpr seq; + + /*! \brief The arguments to the binding. Only the first binding in the graph has arguments + * (i.e., the function arguments). */ + Array args; + + /*! \brief Index of the binding block in the SeqExpr where the binding is found. + * Convention: We put the SeqExpr body at one block past the final block. */ + size_t block_idx; + + /*! \brief Index of the binding within the binding block corresponding to this binding. + * Convention: Both the If condition and merge are mapped to the same index. + * We use the kind to distinguish. */ + size_t binding_idx; + + /*! \brief The kind of binding this is. 
*/ + BindingNodeKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("seq", &seq); + v->Visit("args", &args); + v->Visit("block_idx", &block_idx); + v->Visit("binding_idx", &binding_idx); + v->Visit("kind", &kind); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.GraphBinding"; + TVM_DECLARE_BASE_OBJECT_INFO(GraphBindingNode, Object); +}; + +/*! \brief Representation of a binding in the control flow graph */ +class GraphBinding : public ObjectRef { + public: + /*! + * \brief Create a GraphBinding. See the docs on GraphBindingNode for further details. + * + * \param seq: The SeqExpr in which the binding resides. + * \param args: The arguments to the binding (only nonempty for the first binding: + * these will be the function arguments) + * \param block_idx: The index of the BindingBlock in the SeqExpr + * where the binding resides (for the return expression, use one past the final block). + * \param binding_idx: The index of the binding in the BindingBlock corresponding to the binding. + * \param kind: The kind of binding this is. (Used especially to distinguish If node conditions + * from the merge after the If) + */ + TVM_DLL GraphBinding(const SeqExpr& seq, const Array& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(GraphBinding, ObjectRef, GraphBindingNode); +}; + +/* A control flow graph corresponding to a function. + */ +class ControlFlowGraphNode : public Object { + public: + /*! \brief The bindings in the graph. 0 is the entry point. */ + Array bindings; + /*! \brief The ith member is the list of predecessors (indices) to binding i in bindings. */ + Array> preds; + /*! \brief The ith member is the list of successors (indices) to binding i in bindings. 
*/ + Array> succs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("bindings", &bindings); + v->Visit("preds", &preds); + v->Visit("succs", &succs); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.ControlFlowGraph"; + TVM_DECLARE_BASE_OBJECT_INFO(ControlFlowGraphNode, Object); +}; + +class ControlFlowGraph : public ObjectRef { + public: + /*! + * \brief Create a ControlFlowGraph. + * + * \param bindings: The bindings in the graph + * \param preds: List of lists of predecessors to each binding. + * \param succs: List of lists of successors to each binding. + */ + TVM_DLL ControlFlowGraph(const Array& bindings, const Array>& preds, + const Array>& succs); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ControlFlowGraph, ObjectRef, ControlFlowGraphNode); +}; + +/*! + * \brief Extracts the control flow graph for a Relax function. + * \param func The function. This conversion expects it to be normalized. + * \return The control flow graph corresponding to the function. + */ +ControlFlowGraph ExtractCFG(const Function& func); + +/*! + * \brief Generic implementation of dataflow analysis, based on + * Adrian Sampson's course material, except binding by binding + * instead of basic block by basic block: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * + * The analysis creates input and output maps (mapping binding indices to a domain), + * sets the initial input and output for each binding to the init value, and then + * performs a traversal of the CFG (BFS in this implementation, since unlike the general case, + * we do not have loops) and uses the transfer and merge function to update the inputs and + * outputs. The analysis can proceed forwards (from binding 0 onwards) or backwards (from the + * last binding back), flipping the roles of the input and output maps in the cases. 
+ * + * \param forward Whether to perform a forward or backward analysis + * \param cfg The input control flow graph + * \param init The value corresponding to an initial domain + * \param transfer_func Given an input domain and a binding, determine the resulting domain + * \param merge_func Given a set of domains, combine them to form a single new domain + * (note: in Relax, a binding can never have more than two predecessors/successors) + * + * \return Two arrays, the first being the "input map" (domain being passed *into* + * each binding in the CFG) and the second being the "output map" (the domain + * being passed *out of* the corresponding binding) + */ +std::pair, Array> DataflowAnalysis( + const ControlFlowGraph& cfg, const ObjectRef& init, + std::function transfer_func, + std::function merge_func, bool forward = true); + +/*! \brief A helper function. Given an index into a SeqExpr, give the index of the GraphBinding + * in the CFG. + * + * \param cfg The control flow graph. + * \param seq The target SeqExpr. + * \param block_idx The target block in the SeqExpr. + * Convention: Use one past the last block to indicate the SeqExpr body. + * \param binding_idx The target binding in the target block. + * \param match_cond If the RHS of the target binding is an IfExpr, then if match_cond is true, + * the returned index will be for the condition node; otherwise it will be for the merge node. 
+ */ +size_t GetBindingIndex(const ControlFlowGraph& cfg, const SeqExpr& seq, size_t block_idx, + size_t binding_idx, bool match_cond); + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_DATAFLOW_ANALYSIS_H_ diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py index d8454a02cc84..d2c4c889b5d7 100644 --- a/python/tvm/relax/analysis/__init__.py +++ b/python/tvm/relax/analysis/__init__.py @@ -32,6 +32,7 @@ get_static_type, get_var2val, has_reshape_pattern, + liveness_analysis, name_to_binding, post_order_visit, remove_all_unused, diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 38f5ea2fea0e..4ddcaf2fe3f6 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -21,7 +21,7 @@ configuring the passes and scripting them in Python. """ -from typing import Dict, List, Optional, Union, Callable +from typing import Dict, List, Optional, Set, Union, Callable from enum import IntEnum import tvm @@ -407,6 +407,29 @@ def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]: return _ffi_api.udchain(dfb) # type: ignore +def liveness_analysis(func: Function) -> List[Set[Var]]: + """ + Perform a liveness analysis on the given function, returning a set of + the variables live in the given program location. + + Parameters + ---------- + func: Function + The function to be analyzed + + Returns + ------- + ret: List[Set[Var]] + The set of live variables for each binding in the function. + The indexing is determined by the control flow graph, so + use `extract_cfg` and `get_binding_index` to find the index + for a given program location in the list. 
+ """ + live_lists = _ffi_api.LivenessAnalysis(func) + # convert the lists to sets + return [set(live_list) for live_list in live_lists] + + def name_to_binding(func: Function) -> Dict[str, List[Binding]]: """Return a map from variable name to its bindings.""" return _ffi_api.name_to_binding(func) # type: ignore diff --git a/python/tvm/relax/analysis/dataflow_analysis.py b/python/tvm/relax/analysis/dataflow_analysis.py new file mode 100644 index 000000000000..9c233f27edc0 --- /dev/null +++ b/python/tvm/relax/analysis/dataflow_analysis.py @@ -0,0 +1,217 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Python bindings for the dataflow analysis framework +""" +from enum import Enum +from typing import Any, Callable, List, Tuple +import tvm +from tvm.ir.base import Node +from tvm.relax.expr import SeqExpr, Function, Var +from . 
import _ffi_api + + +class BindingNodeKind(Enum): + Binding = 0 + IfCond = 1 + IfMerge = 2 + SeqBody = 3 + + +@tvm._ffi.register_object("relax.analysis.GraphBinding") +class GraphBinding(Node): + """Representation of a binding in a control flow graph""" + + seq: SeqExpr + args: List[Var] + block_idx: int + binding_idx: int + kind: BindingNodeKind + + def __init__( + self, + seq: SeqExpr, + args: List[Var], + block_idx: int, + binding_idx: int, + kind: BindingNodeKind, + ): + """ + Create a graph binding + + Parameters + ---------- + seq: SeqExpr + The SeqExpr that contains the binding + + args: List[Var] + Arguments taken by the binding (only used for the entry binding: + these will be the function arguments. Otherwise, this array should be empty.) + + block_idx: int + The index of the block in the SeqExpr's block list where the binding resides + (convention: for the SeqExpr body, we will use one past the final block) + + binding_idx: int + The index of the binding in the binding block corresponding to this binding. + + kind: BindingNodeKind + The kind of binding. We distinguish between ordinary bindings, + If conditions, If merges (the var bound to the result of the If node), + and the body of the SeqExpr. 
+ """ + self.__init_handle_by_constructor__( + _ffi_api.GraphBinding, + seq, + args, + block_idx, + binding_idx, + kind, + ) # type: ignore + + +@tvm._ffi.register_object("relax.analysis.ControlFlowGraph") +class ControlFlowGraph(Node): + """Representation of a control flow graph, marking the successors + and predecessors to all basic blocks""" + + def __init__( + self, bindings: List[GraphBinding], preds: List[List[int]], succs: List[List[int]] + ): + """ + Instantiate a control flow graph + + Parameters + ---------- + bindings: List[GraphBnding] + List of bindings in the graph + + preds: List[List[int]] + The ith member is the list of predecessors to bindings[i] (given as indices in bindings) + + succs: List[List[int]] + The ith member is the list of successors to bindings[i] (given as indices in bindings) + """ + if len(bindings) != len(preds) or len(bindings) != len(succs): + raise ValueError("The lengths of blocks, preds, and succs must all match.") + + self.__init_handle_by_constructor__( + _ffi_api.ControlFlowGraph, bindings, preds, succs + ) # type: ignore + + +def extract_cfg(func: Function) -> ControlFlowGraph: + """ + Given a Relax function, produces the corresponding control flow graph. + The function is expected to have been normalized. + + Parameters + ---------- + func: Function + A Relax function. Must be in normal form. + + Returns + ------- + graph: ControlFlowGraph + Control flow graph corresponding to the function. + """ + return _ffi_api.ExtractCFG(func) # type: ignore + + +def get_binding_index( + cfg: ControlFlowGraph, seq: SeqExpr, block_idx: int, binding_idx: int, match_cond: bool = False +) -> int: + """ + Helper function. Given a control flow graph and a seq expression with a block index + and binding index, return the index of the corresponding GraphBinding in the CFG + + Parameters + ---------- + cfg: ControlFlowGraph + The control flow graph. + + seq: SeqExpr + The target SeqExpr. 
+ + block_idx: int + The index of the target block in seq. + Convention: If the target is `seq.body`, block_idx should be one past the last block + (i.e., it should be equal to `len(seq.blocks)`). + + binding_idx: int + The index of the target binding in the target block. + + match_cond: bool + If true and the target binding in seq is an IfNode, then this function will return + the binding index corresponding to the If condition. + If false, then this function will return the binding index corresponding to the If merge. + + Returns + ------- + idx: int + The index of the corresponding GraphBinding in `cfg.bindings`. + """ + return _ffi_api.GetBindingIndex(cfg, seq, block_idx, binding_idx, match_cond) # type: ignore + + +def dataflow_analysis( + cfg: ControlFlowGraph, + init: Any, + transfer_func: Callable[[GraphBinding, Any], Any], + merge_func: Callable[[Any, Any], Any], + forward: bool = True, +) -> Tuple[List[Any], List[Any]]: + """ + Generic dataflow analysis framework, based on Adrian Sampson's course notes, + except binding by binding instead of basic block by basic block: + https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + + The analysis creates input and output maps (mapping binding indices to a domain), + sets the initial input and output for each binding to the init value, and then + performs a traversal of the CFG (BFS in this implementation, since unlike the general case, + we do not have loops) and uses the transfer and merge function to update the inputs and + outputs. The analysis can proceed forwards (from binding 0 onwards) or backwards (from the last + binding back), flipping the roles of the input and output maps in the cases. + + Parameters + ---------- + cfg: ControlFlowGraph + The input control flow graph + + init: Any + The initial value in the analysis domain to which all blocks should be initialized. 
+ + transfer_func: Callable[[GraphBinding, Any], Any] + Given a binding and the input domain, compute the new output domain. 
+ + merge_func: Callable[[Any, Any], Any] + When two output domains are fed into a single block (i.e., after an If branch), + the merge function is used to combine them into a single domain. + + forward: bool + If true, the analysis proceeds forwards (starting from binding 0 and going onwards). + If false, the analysis proceeds backwards (starting from the last binding and going back). + The input and output maps play the opposite roles in forward and backward analyses. + I.e., in a backward analysis, the "final output" is the input map entry for binding 0 + and the initial input is the output map entry for the last binding. + + Returns + ------- + ret: Tuple[List[Any], List[Any]] + A pair of the final input and output maps + """ + return _ffi_api.DataflowAnalysis(cfg, init, transfer_func, merge_func, forward) # type: ignore diff --git a/src/relax/analysis/dataflow_analysis.cc b/src/relax/analysis/dataflow_analysis.cc new file mode 100644 index 000000000000..151f4d18c749 --- /dev/null +++ b/src/relax/analysis/dataflow_analysis.cc @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/relax/analysis/dataflow_analysis.cc + * \brief Implementation of functionality in dataflow_analysis.h + */ +#include +#include + +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(GraphBindingNode); + +GraphBinding::GraphBinding(const SeqExpr& seq, const Array& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind) { + ObjectPtr n = make_object(); + n->seq = seq; + n->args = args; + n->block_idx = block_idx; + n->binding_idx = binding_idx; + n->kind = kind; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ControlFlowGraphNode); + +ControlFlowGraph::ControlFlowGraph(const Array& bindings, + const Array>& preds, + const Array>& succs) { + ObjectPtr n = make_object(); + n->bindings = bindings; + n->preds = preds; + n->succs = succs; + data_ = std::move(n); +} + +// Extracts a basic block and updates the running lists bindings, preds, and succs. +// The return value is the index of the final binding processed in the seq expression +// (useful for processing branches). 
+size_t ExtractCFGHelper(const SeqExpr& seq, const Array& args, size_t block_idx, + size_t binding_idx, std::vector current_preds, + std::vector* bindings, + std::vector>* preds, + std::vector>* succs) { + // case 1: We're past the end -> this is the block body (base case) + if (block_idx == seq->blocks.size()) { + bindings->push_back(GraphBinding(seq, args, block_idx, 0U, BindingNodeKind::kSeqBody)); + preds->push_back(current_preds); + // the final binding has no successors + succs->push_back({}); + return bindings->size() - 1; + } + + Binding binding = seq->blocks[block_idx]->bindings[binding_idx]; + Expr binding_value = GetBoundValue(binding); + + // case 2: Ordinary binding + if (!binding_value.as()) { + bindings->push_back(GraphBinding(seq, args, block_idx, binding_idx, BindingNodeKind::kBinding)); + size_t idx = bindings->size() - 1; + preds->push_back(current_preds); + // successor: the next binding (there will always be at least one binding after this, + // even if it's the seq body) + succs->push_back({idx + 1}); + } else { + // case 3: dealing with a branch + auto if_node = Downcast(binding_value); + // start with the cond node + bindings->push_back(GraphBinding(seq, args, block_idx, binding_idx, BindingNodeKind::kIfCond)); + size_t idx = bindings->size() - 1; + preds->push_back(current_preds); + // there will be another successor, which we will add after recursing down the branches + succs->push_back({idx + 1}); + size_t final_true_idx = ExtractCFGHelper(Downcast(if_node->true_branch), {}, 0U, 0U, + {idx}, bindings, preds, succs); + succs->at(idx).push_back(final_true_idx + 1); + size_t final_false_idx = ExtractCFGHelper(Downcast(if_node->false_branch), {}, 0U, 0U, + {idx}, bindings, preds, succs); + // now create the merge + bindings->push_back(GraphBinding(seq, {}, block_idx, binding_idx, BindingNodeKind::kIfMerge)); + size_t merge_idx = bindings->size() - 1; + preds->push_back({final_true_idx, final_false_idx}); + succs->push_back({merge_idx + 1}); + 
// update the successors of the final true and false indices as well + succs->at(final_true_idx).push_back(merge_idx); + succs->at(final_false_idx).push_back(merge_idx); + } + // move on to next binding + size_t next_block_idx = block_idx; + size_t next_binding_idx = binding_idx + 1; + if (next_binding_idx >= seq->blocks[block_idx]->bindings.size()) { + next_block_idx = block_idx + 1; + next_binding_idx = 0U; + } + return ExtractCFGHelper(seq, {}, next_block_idx, next_binding_idx, {bindings->size() - 1}, + bindings, preds, succs); +} + +ControlFlowGraph ExtractCFG(const Function& func) { + std::vector bindings; + std::vector> preds; + std::vector> succs; + ExtractCFGHelper(Downcast(func->body), func->params, 0U, 0U, {}, &bindings, &preds, + &succs); + + Array> pred_arr; + for (auto pred_vec : preds) { + Array pred_ints; + for (auto idx : pred_vec) { + pred_ints.push_back(Integer(idx)); + } + pred_arr.push_back(pred_ints); + } + Array> succ_arr; + for (auto succ_vec : succs) { + Array succ_ints; + for (auto idx : succ_vec) { + succ_ints.push_back(Integer(idx)); + } + succ_arr.push_back(succ_ints); + } + return ControlFlowGraph(Array(bindings), pred_arr, succ_arr); +} + +std::pair, Array> DataflowAnalysis( + const ControlFlowGraph& cfg, const ObjectRef& init, + std::function transfer_func, + std::function merge_func, bool forward) { + std::vector in_map; + std::vector out_map; + for (size_t i = 0; i < cfg->bindings.size(); i++) { + in_map.push_back(init); + out_map.push_back(init); + } + + // Modification from Adrian Sampson's version: + // Since there are no loops in our AST, one traversal through the CFG suffices. + // We will do BFS + std::queue worklist; + worklist.push((forward) ? 0 : cfg->bindings.size() - 1); + while (!worklist.empty()) { + size_t idx = worklist.front(); + worklist.pop(); + Array prev = (forward) ? cfg->preds[idx] : cfg->succs[idx]; + Array next = (forward) ? cfg->succs[idx] : cfg->preds[idx]; + std::vector* results = (forward) ? 
&out_map : &in_map; + std::vector* inputs = (forward) ? &in_map : &out_map; + + // Cases (for forward analysis): + // 0 predecessors: The first block in the function + // 1 predecessor: A branch in an If node (no merge needed) + // 2 predecessors: The merge block after an If node (merge needed) + // (Analogous for successors in backward analysis) + inputs->operator[](idx) = (prev.size() == 0) ? init + : (prev.size() == 1) ? results->at(prev[0].IntValue()) + : merge_func(results->at(prev[0].IntValue()), + results->at(prev[1].IntValue())); + results->operator[](idx) = transfer_func(cfg->bindings[idx], inputs->at(idx)); + + for (Integer next_idx : next) { + worklist.push(next_idx.IntValue()); + } + } + + return {Array(in_map), Array(out_map)}; +} + +size_t GetBindingIndex(const ControlFlowGraph& cfg, const SeqExpr& seq, size_t block_idx, + size_t binding_idx, bool match_cond) { + bool is_body = (block_idx == seq->blocks.size()); + bool is_if = + (!is_body && (GetBoundValue(seq->blocks[block_idx]->bindings[binding_idx]).as())); + + // This is an inefficient linear scan; it could be improved by keeping a map of + // SeqExprs to indices in the CFG data structure. + // That should be considered if this function poses performance issues (unlikely). 
+ for (size_t i = 0; i < cfg->bindings.size(); i++) { + auto binding = cfg->bindings[i]; + if (binding->seq != seq) { + continue; + } + if (is_body && binding->kind == BindingNodeKind::kSeqBody) { + return i; + } + if (binding->block_idx == block_idx && binding->binding_idx == binding_idx) { + if (!is_if || (match_cond && binding->kind == BindingNodeKind::kIfCond) || + (!match_cond && binding->kind == BindingNodeKind::kIfMerge)) { + return i; + } + } + } + CHECK(false) << "Target binding does not appear in the given CFG"; + return cfg->bindings.size(); +} + +TVM_REGISTER_GLOBAL("relax.analysis.GraphBinding") + .set_body_typed([](const SeqExpr& seq, const Array& args, size_t block_idx, + size_t binding_idx, int kind) { + return GraphBinding(seq, args, block_idx, binding_idx, static_cast(kind)); + }); + +TVM_REGISTER_GLOBAL("relax.analysis.ControlFlowGraph") + .set_body_typed([](const Array& blocks, const Array>& preds, + const Array>& succs) { + return ControlFlowGraph(blocks, preds, succs); + }); + +TVM_REGISTER_GLOBAL("relax.analysis.ExtractCFG").set_body_typed(ExtractCFG); + +TVM_REGISTER_GLOBAL("relax.analysis.DataflowAnalysis") + .set_body_typed([](const ControlFlowGraph& cfg, const ObjectRef& init, PackedFunc transfer_func, + PackedFunc merge_func, bool forward) { + auto ret = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward); + return Array({ret.first, ret.second}); + }); + +// need to turn the size_t's into ints in order to cross the C++<->Python boundary +TVM_REGISTER_GLOBAL("relax.analysis.GetBindingIndex") + .set_body_typed([](const ControlFlowGraph& cfg, const SeqExpr& seq, int block_idx, + int binding_idx, bool match_cond) -> int { + return GetBindingIndex(cfg, seq, block_idx, binding_idx, match_cond); + }); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/analysis/liveness.cc b/src/relax/analysis/liveness.cc new file mode 100644 index 000000000000..548026e3bb67 --- /dev/null +++ b/src/relax/analysis/liveness.cc @@ -0,0 
+1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/analysis/liveness.cc + * \brief Implementation of liveness analysis + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +// just sets of vars. the bool value is unnecessary +using Domain = Map; + +Domain transfer_func(const GraphBinding& binding, const ObjectRef& input) { + Domain in_domain = Downcast(input); + Domain new_domain(in_domain); + + // 1. If a var appears in the RHS of the binding, add it (it's live) + // 2. 
Remove the bound var (it is not live prior to being bound) + Array vars_used; + Optional var_bound; + if (binding->kind == BindingNodeKind::kSeqBody) { + vars_used = AllVars(binding->seq->body); + } else if (binding->kind == BindingNodeKind::kIfCond) { + Binding b = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]; + Expr cond = Downcast(GetBoundValue(b))->cond; + vars_used = AllVars(cond); + } else if (binding->kind == BindingNodeKind::kIfMerge) { + // no vars are used in the merge + vars_used = {}; + // define the merge var + var_bound = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]->var; + } else { + // the ordinary binding case + Binding b = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]; + Expr bound_value = GetBoundValue(b); + // For a function literal, we only care about the free vars + // (those captured by the closure). + // In all other cases, we want any var, but since the RHS would not contain + // any bindings of its own, that turns out to be the same as the free vars. 
+ vars_used = FreeVars(bound_value); + var_bound = b->var; + } + + for (auto var : vars_used) { + new_domain.Set(var, Bool(true)); + } + + // the var bound is killed + if (var_bound.defined()) { + new_domain.erase(var_bound.value()); + } + + // technically, we could kill the args too, + // but they are not actually *bound* at the first binding + + return new_domain; +} + +// simply combine sets of live vars to merge +Domain merge_func(const ObjectRef& domain1, const ObjectRef& domain2) { + Domain merged; + for (auto kv : Downcast(domain1)) { + merged.Set(kv.first, kv.second); + } + for (auto kv : Downcast(domain2)) { + merged.Set(kv.first, kv.second); + } + return merged; +} + +Array> LivenessAnalysis(const Function& func) { + // initial domain is empty + Domain init_domain; + ControlFlowGraph cfg = ExtractCFG(func); + std::pair results = + DataflowAnalysis(cfg, init_domain, transfer_func, merge_func, false); + + // we will return the input map but convert the maps into arrays for simplicity + + // The map is done for safety, since directly doing Downcast>(results.first) + // would *not* check the contents of results.first. + Array res_objs = Downcast>(results.first); + Array in_map = res_objs.Map([](const ObjectRef& obj) { return Downcast(obj); }); + + Array> ret; + for (const Domain& d : in_map) { + Array arr; + for (auto kv : d) { + arr.push_back(kv.first); + } + ret.push_back(arr); + } + return ret; +} + +TVM_REGISTER_GLOBAL("relax.analysis.LivenessAnalysis").set_body_typed(LivenessAnalysis); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis_liveness_analysis.py b/tests/python/relax/test_analysis_liveness_analysis.py new file mode 100644 index 000000000000..9483073ef08a --- /dev/null +++ b/tests/python/relax/test_analysis_liveness_analysis.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Set
import tvm
import tvm.testing
from tvm.script import ir as I, relax as R
from tvm.relax import Var
from tvm.relax.analysis import liveness_analysis


def assert_live_set(live_set: Set[Var], var_names: Set[str]) -> None:
    """Assert that the name hints of the vars in `live_set` are exactly `var_names`.

    Parameters
    ----------
    live_set
        The collection of live variables reported for one location.
    var_names
        The expected variable names. Any iterable of names is accepted
        (it is normalized with `set`), so an empty `{}` also works.

    Notes
    -----
    Both the size and the set of names are compared. Checking only membership
    would let two distinct vars that share a name hint hide a missing name.
    """
    expected = set(var_names)
    assert len(live_set) == len(expected)
    assert {var.name_hint for var in live_set} == expected


def test_simple_liveness():
    """Straight-line function: exactly one variable is live at each location."""

    @I.ir_module
    class SimpleFunc:
        @R.function
        def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
            y = R.add(x, x)  # live: x
            z = R.add(y, y)  # live: y
            return z  # live: z

    live_sets = liveness_analysis(SimpleFunc["main"])
    assert_live_set(live_sets[0], {"x"})
    assert_live_set(live_sets[1], {"y"})
    assert_live_set(live_sets[2], {"z"})


def test_liveness_with_branches():
    """Liveness across an If node, including the branch ends and the merge point."""

    @I.ir_module
    class BranchingFunc:
        @R.function
        def main(
            x: R.Tensor((), dtype="int32"),
            y: R.Tensor((), dtype="int32"),
            cond: R.Tensor((), dtype="bool"),
        ) -> R.Tensor((), dtype="int32"):
            z = R.add(x, x)  # live: x, y, cond
            q = R.add(z, z)  # live: y, z, cond
            if cond:  # live: q, y, cond
                r = R.subtract(q, y)  # live: q, y
                s = R.multiply(r, r)  # live: r
                # end of seq: the R.multiply will actually be bound to a fresh var
                # and s will be used as the binding for the entire If node
            else:
                r = R.multiply(q, q)  # live: q, y
                s = R.subtract(r, y)  # live: r, y
                # end of seq: the R.subtract will actually be bound to a fresh var
                # and s will be used as the binding for the entire If node
            # merge point: nothing is live (s is the variable bound at the merge)
            t = R.add(s, s)  # live: s
            u = R.multiply(t, s)  # live: t, s
            return u  # live: u

    live_sets = liveness_analysis(BranchingFunc["main"])
    assert_live_set(live_sets[0], {"x", "y", "cond"})
    assert_live_set(live_sets[1], {"y", "z", "cond"})
    assert_live_set(live_sets[2], {"q", "y", "cond"})
    assert_live_set(live_sets[3], {"q", "y"})
    assert_live_set(live_sets[4], {"r"})
    # the name is created by the parser and will be a placeholder so this is the best we can do
    assert len(live_sets[5]) == 1 and (
        BranchingFunc["main"].body.blocks[0].bindings[2].value.true_branch.body in live_sets[5]
    )
    assert_live_set(live_sets[6], {"q", "y"})
    assert_live_set(live_sets[7], {"r", "y"})
    assert len(live_sets[8]) == 1 and (
        BranchingFunc["main"].body.blocks[0].bindings[2].value.false_branch.body in live_sets[8]
    )
    # nothing is live at the merge point; use an actual empty set rather than `{}` (a dict)
    assert_live_set(live_sets[9], set())
    assert_live_set(live_sets[10], {"s"})
    assert_live_set(live_sets[11], {"t", "s"})
    assert_live_set(live_sets[12], {"u"})


def test_liveness_inner_func():
    """An inner function that captures outer vars counts as a use of those vars."""

    @I.ir_module
    class InnerFunc:
        @R.function
        def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
            y = R.add(x, x)  # live: x
            z = R.add(y, y)  # live: x, y

            # the inner func captures x and y and so counts as a use of both
            # live: x, y, z
            @R.function
            def inner(q: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
                # (note: we would need to do liveness analysis of the inner func
                # separately to get liveness info for these locations)
                r = R.add(x, q)  # live: x, y, q (and z from outside)
                s = R.multiply(y, r)  # live: y, r (and z from outside)
                return s  # live: s (and z from outside)

            w = inner(z)  # live: inner, z
            return w  # live: w

    live_sets = liveness_analysis(InnerFunc["main"])
    assert_live_set(live_sets[0], {"x"})
    assert_live_set(live_sets[1], {"x", "y"})
    assert_live_set(live_sets[2], {"x", "y", "z"})
    assert_live_set(live_sets[3], {"inner", "z"})
    assert_live_set(live_sets[4], {"w"})

    # let's also analyze the inner func (note: we don't have a way to indicate
    # that z is live from outside the func)
    inner_live = liveness_analysis(InnerFunc["main"].body.blocks[0].bindings[2].value)
    assert_live_set(inner_live[0], {"x", "y", "q"})
    assert_live_set(inner_live[1], {"y", "r"})
    assert_live_set(inner_live[2], {"s"})


if __name__ == "__main__":
    tvm.testing.main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Callable, List, Optional
import tvm
from tvm import relax
from tvm.relax.analysis.dataflow_analysis import (
    ControlFlowGraph,
    extract_cfg,
    dataflow_analysis,
    BindingNodeKind,
    get_binding_index,
)
from tvm.script import ir as I, relax as R
import tvm.testing


def assert_pred_succ_lists(graph: ControlFlowGraph, expected_preds: List[List[int]]) -> None:
    """Check the predecessor lists of `graph` and the successor lists they imply.

    `expected_preds[i]` lists the indices of the predecessors of node `i`. The
    expected successor lists are derived by inverting `expected_preds`, so callers
    only have to spell out one direction.
    """
    assert tuple([tuple(preds) for preds in graph.preds]) == tuple(
        [tuple(exp_preds) for exp_preds in expected_preds]
    )

    expected_succs = [[] for preds in expected_preds]
    # we can automatically invert the predecessor list
    # this also guarantees consistency
    for i, pred_list in enumerate(expected_preds):
        for pred in pred_list:
            expected_succs[pred].append(i)

    assert tuple([tuple(succs) for succs in graph.succs]) == tuple(
        [tuple(exp_succs) for exp_succs in expected_succs]
    )


def assert_binding_fields(
    graph: ControlFlowGraph,
    idx: int,
    block_idx: int,
    binding_idx: int,
    kind: BindingNodeKind = BindingNodeKind.Binding,
    args: Optional[List[relax.Var]] = None,
) -> None:
    """Check the location fields of the CFG node at position `idx`.

    Compares the node's `block_idx`, `binding_idx` and `kind` (against
    `kind.value`) with the given values. `args` is compared element by element
    only when it is provided.
    """
    binding = graph.bindings[idx]
    assert binding.block_idx == block_idx
    assert binding.binding_idx == binding_idx
    assert binding.kind == kind.value
    if args is not None:
        assert len(binding.args) == len(args)
        for i in range(len(args)):
            assert binding.args[i] == args[i]


# assert that the SeqExprs for each binding match within groups and do not match other groups
def assert_distinct_seqs(cfg: ControlFlowGraph, *groups: List[int]) -> None:
    """Check which CFG nodes share a `seq`.

    Each group is a list of node indices. Every node in a group must have the
    same `seq` as the group's first node, and the first node of a group must
    have a different `seq` from every node of every later group. Empty groups
    are skipped.
    """
    for i, group in enumerate(groups):
        if len(group) == 0:
            continue
        for idx in group[1:]:
            assert cfg.bindings[idx].seq == cfg.bindings[group[0]].seq
        for other_group in groups[i + 1 :]:
            for idx in other_group:
                assert cfg.bindings[group[0]].seq != cfg.bindings[idx].seq


def test_trivial_CFG():
    """A function with no bindings yields a single SeqBody node with no edges."""

    @I.ir_module
    class TrivialFunc:
        @R.function
        def main() -> R.Tensor((), "int32"):
            return R.const(1, dtype="int32")

    graph = extract_cfg(TrivialFunc["main"])
    assert len(graph.bindings) == 1
    assert_pred_succ_lists(graph, [[]])
    assert_binding_fields(graph, 0, 0, 0, kind=BindingNodeKind.SeqBody)


def test_sequence_of_bindings():
    """Three bindings plus the body form a linear chain of four nodes."""

    @I.ir_module
    class FuncWithBindings:
        @R.function
        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            y = R.add(x, x)
            z = R.add(y, x)
            q = R.multiply(z, x)
            return q

    graph = extract_cfg(FuncWithBindings["main"])
    assert len(graph.bindings) == 4
    assert_pred_succ_lists(graph, [[], [0], [1], [2]])
    assert_binding_fields(graph, 0, 0, 0, args=[FuncWithBindings["main"].params[0]])
    assert_binding_fields(graph, 1, 0, 1)
    assert_binding_fields(graph, 2, 0, 2)
    # the body is placed one block past the final binding block
    assert_binding_fields(graph, 3, 1, 0, kind=BindingNodeKind.SeqBody)


def test_dataflow_block():
    """Dataflow blocks do not branch: the chain stays linear and only block_idx advances."""

    @I.ir_module
    class FuncWithDataflow:
        @R.function
        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            y = R.add(x, x)
            z = R.add(y, y)
            with R.dataflow():
                q = R.multiply(z, z)
                r = R.add(q, q)
                R.output(r)
            s = R.add(r, r)
            with R.dataflow():
                t = R.multiply(s, s)
                u = R.add(t, t)
                R.output(u)
            return u

    graph = extract_cfg(FuncWithDataflow["main"])
    assert len(graph.bindings) == 8
    assert_pred_succ_lists(graph, [[], [0], [1], [2], [3], [4], [5], [6]])
    assert_binding_fields(graph, 0, 0, 0, args=FuncWithDataflow["main"].params)
    assert_binding_fields(graph, 1, 0, 1)
    assert_binding_fields(graph, 2, 1, 0)
    assert_binding_fields(graph, 3, 1, 1)
    assert_binding_fields(graph, 4, 2, 0)
    assert_binding_fields(graph, 5, 3, 0)
    assert_binding_fields(graph, 6, 3, 1)
    assert_binding_fields(graph, 7, 4, 0, kind=BindingNodeKind.SeqBody)


def test_simple_branch():
    """A function whose first binding is an If: split at the condition, join at the merge."""

    @I.ir_module
    class SimpleBranch:
        @R.function
        def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"):
            if cond:
                x = R.const(1, dtype="int32")
                y = R.add(x, x)
                z = R.multiply(y, y)
            else:
                x = R.const(2, dtype="int32")
                y = R.add(x, x)
                z = R.multiply(y, y)
            return z

    graph = extract_cfg(SimpleBranch["main"])

    # cond binding + 3 bindings in true branch + true branch end
    # + 3 bindings in false branch + false branch end + merge + seq body
    assert len(graph.bindings) == 11
    assert_pred_succ_lists(graph, [[], [0], [1], [2], [3], [0], [5], [6], [7], [4, 8], [9]])

    assert_binding_fields(
        graph, 0, 0, 0, kind=BindingNodeKind.IfCond, args=SimpleBranch["main"].params
    )
    assert_binding_fields(graph, 1, 0, 0)
    assert_binding_fields(graph, 2, 0, 1)
    assert_binding_fields(graph, 3, 0, 2)
    assert_binding_fields(graph, 4, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_binding_fields(graph, 5, 0, 0)
    assert_binding_fields(graph, 6, 0, 1)
    assert_binding_fields(graph, 7, 0, 2)
    assert_binding_fields(graph, 8, 1, 0, kind=BindingNodeKind.SeqBody)
    # the merge shares block/binding indices with the condition; only the kind differs
    assert_binding_fields(graph, 9, 0, 0, kind=BindingNodeKind.IfMerge)
    assert_binding_fields(graph, 10, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_distinct_seqs(graph, [0, 9], [1, 4], [5, 8])


def test_bindings_after_branch():
    """An If preceded and followed by ordinary bindings in the same SeqExpr."""

    @I.ir_module
    class BranchAndBind:
        @R.function
        def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"):
            x = R.const(1, dtype="int32")
            y = R.add(x, x)
            if cond:
                z = R.multiply(y, y)
            else:
                z = R.add(y, y)
            q = R.add(z, z)
            return q

    graph = extract_cfg(BranchAndBind["main"])
    assert len(graph.bindings) == 10
    assert_pred_succ_lists(graph, [[], [0], [1], [2], [3], [2], [5], [4, 6], [7], [8]])
    assert_binding_fields(graph, 0, 0, 0, args=BranchAndBind["main"].params)
    assert_binding_fields(graph, 1, 0, 1)
    assert_binding_fields(graph, 2, 0, 2, kind=BindingNodeKind.IfCond)
    assert_binding_fields(graph, 3, 0, 0)
    assert_binding_fields(graph, 4, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_binding_fields(graph, 5, 0, 0)
    assert_binding_fields(graph, 6, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_binding_fields(graph, 7, 0, 2, kind=BindingNodeKind.IfMerge)
    assert_binding_fields(graph, 8, 0, 3)
    assert_binding_fields(graph, 9, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_distinct_seqs(graph, [0, 2, 7, 9], [3, 4], [5, 6])


def test_branch_with_multiple_blocks():
    """Branches that each contain several binding blocks, including a dataflow block."""

    @I.ir_module
    class LongBranches:
        @R.function
        def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"):
            if cond:
                x = R.const(1, dtype="int32")
                y = R.add(x, x)
                with R.dataflow():
                    z = R.multiply(y, y)
                    w = R.add(z, z)
                    v = R.multiply(w, w)
                    R.output(v)
                q = R.add(v, v)
                r = R.multiply(q, q)
            else:
                x = R.const(2, dtype="int32")
                y = R.multiply(x, x)
                with R.dataflow():
                    z = R.add(y, y)
                    w = R.multiply(z, z)
                    v = R.add(w, w)
                    R.output(v)
                q = R.multiply(v, v)
                r = R.add(q, q)
            return r

    graph = extract_cfg(LongBranches["main"])
    # empty entry block, one block for each branch, and an empty exit block
    # NOTE(review): the comment above does not obviously correspond to the 19
    # per-binding nodes asserted below (cond + 2 x (7 bindings + branch end)
    # + merge + body) -- confirm whether it is stale.
    assert len(graph.bindings) == 19
    assert_pred_succ_lists(
        graph,
        [
            [],
            [0],
            [1],
            [2],
            [3],
            [4],
            [5],
            [6],
            [7],
            [0],
            [9],
            [10],
            [11],
            [12],
            [13],
            [14],
            [15],
            [8, 16],
            [17],
        ],
    )

    assert_binding_fields(
        graph, 0, 0, 0, kind=BindingNodeKind.IfCond, args=LongBranches["main"].params
    )
    assert_binding_fields(graph, 1, 0, 0)
    assert_binding_fields(graph, 2, 0, 1)
    assert_binding_fields(graph, 3, 1, 0)
    assert_binding_fields(graph, 4, 1, 1)
    assert_binding_fields(graph, 5, 1, 2)
    assert_binding_fields(graph, 6, 2, 0)
    assert_binding_fields(graph, 7, 2, 1)
    assert_binding_fields(graph, 8, 3, 0, kind=BindingNodeKind.SeqBody)
    assert_binding_fields(graph, 9, 0, 0)
    assert_binding_fields(graph, 10, 0, 1)
    assert_binding_fields(graph, 11, 1, 0)
    assert_binding_fields(graph, 12, 1, 1)
    assert_binding_fields(graph, 13, 1, 2)
    assert_binding_fields(graph, 14, 2, 0)
    assert_binding_fields(graph, 15, 2, 1)
    assert_binding_fields(graph, 16, 3, 0, kind=BindingNodeKind.SeqBody)
    assert_binding_fields(graph, 17, 0, 0, kind=BindingNodeKind.IfMerge)
    assert_binding_fields(graph, 18, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_distinct_seqs(graph, [0, 17, 18], [1, 8], [9, 16])


def test_nested_branches():
    """An If inside each branch of an outer If; every SeqExpr gets its own group of nodes."""

    @I.ir_module
    class NestedBranches:
        @R.function
        def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
            cond1 = R.const(True, dtype="bool")
            if cond1:
                cond2 = R.const(False, dtype="bool")
                if cond2:
                    y = R.add(x, x)
                else:
                    y = R.multiply(x, x)
                z = R.add(y, y)
            else:
                cond3 = R.const(True, dtype="bool")
                if cond3:
                    y = R.multiply(x, x)
                else:
                    y = R.add(x, x)
                z = R.multiply(y, y)
            return z

    graph = extract_cfg(NestedBranches["main"])
    assert len(graph.bindings) == 22
    assert_pred_succ_lists(
        graph,
        [
            [],  # first binding
            [0],  # branch cond
            [1],  # first binding in true branch
            [2],  # nested if condition
            [3],  # binding inside nested true branch
            [4],  # end of nested true branch
            [3],  # binding inside nested false branch
            [6],  # end of nested false branch
            [5, 7],  # merge for nested if
            [8],  # binding after nested if
            [9],  # end of outer true branch
            [1],  # first binding in false branch
            [11],  # nested if condition
            [12],  # binding inside nested true branch
            [13],  # end of nested true branch
            [12],  # binding inside nested false branch
            [15],  # end of nested false branch
            [14, 16],  # merge after nested if
            [17],  # binding after nested if
            [18],  # end of outer false branch
            [10, 19],  # merge after outer if
            [20],  # end of body
        ],
    )

    assert_binding_fields(graph, 0, 0, 0, args=NestedBranches["main"].params)
    assert_binding_fields(graph, 1, 0, 1, kind=BindingNodeKind.IfCond)
    assert_binding_fields(graph, 2, 0, 0)
    assert_binding_fields(graph, 3, 0, 1, kind=BindingNodeKind.IfCond)
    assert_binding_fields(graph, 4, 0, 0)
    assert_binding_fields(graph, 5, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_binding_fields(graph, 6, 0, 0)
    assert_binding_fields(graph, 7, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_binding_fields(graph, 8, 0, 1, kind=BindingNodeKind.IfMerge)
    assert_binding_fields(graph, 9, 0, 2)
    assert_binding_fields(graph, 10, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_binding_fields(graph, 11, 0, 0)
    assert_binding_fields(graph, 12, 0, 1, kind=BindingNodeKind.IfCond)
    assert_binding_fields(graph, 13, 0, 0)
    assert_binding_fields(graph, 14, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_binding_fields(graph, 15, 0, 0)
    assert_binding_fields(graph, 16, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_binding_fields(graph, 17, 0, 1, kind=BindingNodeKind.IfMerge)
    assert_binding_fields(graph, 18, 0, 2)
    assert_binding_fields(graph, 19, 1, 0, kind=BindingNodeKind.SeqBody)
    assert_binding_fields(graph, 20, 0, 1, kind=BindingNodeKind.IfMerge)
    assert_binding_fields(graph, 21, 1, 0, kind=BindingNodeKind.SeqBody)

    assert_distinct_seqs(
        graph,
        [0, 1, 20, 21],
        [2, 3, 8, 9, 10],
        [4, 5],
        [6, 7],
        [11, 12, 17, 18, 19],
        [13, 14],
        [15, 16],
    )


def test_simple_analysis():
    """Run a toy analysis over a one-node CFG in both directions.

    The transfer function copies its input and adds the key "b". With a single
    node no merge happens, so the expected maps are the same in both directions
    with the roles of `in_map` and `out_map` swapped.
    """

    @I.ir_module
    class TrivialFunc:
        @R.function
        def main() -> R.Tensor((), "int32"):
            return R.const(1, dtype="int32")

    # only one binding to consider here
    init = {"a": 1}

    def transfer_func(_, domain):
        # the input domain will be converted into an immutable TVM Map,
        # so we have to create a new domain
        new_domain = {}
        for k, v in domain.items():
            new_domain[k] = v
        new_domain["b"] = 2
        return new_domain

    # there will not be a merge here
    merge_func = lambda domain1, _: domain1

    def check_expected_maps(in_map, out_map):
        # we expect the in map to be the init value and the out map to have the key b
        assert len(in_map[0]) == 1
        assert in_map[0]["a"] == 1
        assert len(out_map[0]) == 2
        assert out_map[0]["a"] == 1
        assert out_map[0]["b"] == 2

    cfg = extract_cfg(TrivialFunc["main"])
    in_map, out_map = dataflow_analysis(cfg, init, transfer_func, merge_func, forward=True)
    check_expected_maps(in_map, out_map)
    # backward will just flip in and out
    in_map, out_map = dataflow_analysis(cfg, init, transfer_func, merge_func, forward=False)
    check_expected_maps(out_map, in_map)


def test_simple_analysis_with_merge():
    """Run a counting analysis over a branching CFG in both directions.

    The transfer function increments every value by one. The merge function
    keeps the larger value for each key and adds the key "merge" with value 1
    when it is absent, so the assertions can tell where a merge took place.
    """

    @I.ir_module
    class SimpleBranch:
        @R.function
        def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"):
            if cond:
                x = R.const(1, dtype="int32")
                y = R.add(x, x)
                z = R.multiply(y, y)
            else:
                x = R.const(2, dtype="int32")
                y = R.add(x, x)
                z = R.multiply(y, y)
            return z

    init = {"a": 1}

    def transfer_func(_, domain):
        new_domain = {}
        for k, v in domain.items():
            new_domain[k] = v + 1
        return new_domain

    def merge_func(domain1, domain2):
        new_domain = {}
        for k, v in domain1.items():
            new_domain[k] = v
        for k, v in domain2.items():
            if k not in new_domain or (k in new_domain and new_domain[k] < v):
                new_domain[k] = v
        if "merge" not in new_domain:
            new_domain["merge"] = 1
        return new_domain

    cfg = extract_cfg(SimpleBranch["main"])
    in_map, out_map = dataflow_analysis(cfg, init, transfer_func, merge_func, forward=True)
    # start and true branch
    for i in range(5):
        assert in_map[i]["a"] == i + 1
        assert out_map[i]["a"] == i + 2
    # false branch
    for i in range(5, 9):
        assert in_map[i]["a"] == i - 3
        assert out_map[i]["a"] == i - 2
    # index 9 is the merge
    assert in_map[9]["a"] == 6
    assert in_map[9]["merge"] == 1
    assert out_map[9]["a"] == 7
    assert out_map[9]["merge"] == 2
    # index 10 is the last
    assert in_map[10]["a"] == 7
    assert in_map[10]["merge"] == 2
    assert out_map[10]["a"] == 8
    assert out_map[10]["merge"] == 3

    in_map, out_map = dataflow_analysis(cfg, init, transfer_func, merge_func, forward=False)
    # backward direction: start with index 10
    # end of seq through false branch
    for i in range(6):
        assert out_map[10 - i]["a"] == i + 1
        assert in_map[10 - i]["a"] == i + 2
    # true branch
    for i in range(4):
        assert out_map[4 - i]["a"] == i + 3
        assert in_map[4 - i]["a"] == i + 4
    # the if condition is the merge
    assert out_map[0]["a"] == 7
    assert out_map[0]["merge"] == 1
    assert in_map[0]["a"] == 8
    assert in_map[0]["merge"] == 2


def test_get_binding_index():
    """Map (SeqExpr, block index, binding index) triples back to CFG node indices.

    The If condition and its merge share the same block and binding index;
    `match_cond=True` selects the condition and the default selects the merge.
    """

    @I.ir_module
    class BranchAndBind:
        @R.function
        def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"):
            x = R.const(1, dtype="int32")
            y = R.add(x, x)
            if cond:
                z = R.multiply(y, y)
            else:
                z = R.add(y, y)
            q = R.add(z, z)
            return q

    graph = extract_cfg(BranchAndBind["main"])
    outer_seq = BranchAndBind["main"].body
    true_seq = BranchAndBind["main"].body.blocks[0].bindings[2].value.true_branch
    false_seq = BranchAndBind["main"].body.blocks[0].bindings[2].value.false_branch

    assert get_binding_index(graph, outer_seq, 0, 0) == 0
    assert get_binding_index(graph, outer_seq, 0, 1) == 1
    assert get_binding_index(graph, outer_seq, 0, 2, match_cond=True) == 2
    assert get_binding_index(graph, true_seq, 0, 0) == 3
    assert get_binding_index(graph, true_seq, 1, 0) == 4
    assert get_binding_index(graph, false_seq, 0, 0) == 5
    assert get_binding_index(graph, false_seq, 1, 0) == 6
    assert get_binding_index(graph, outer_seq, 0, 2) == 7  # the merge
    assert get_binding_index(graph, outer_seq, 0, 3) == 8
    assert get_binding_index(graph, outer_seq, 1, 0) == 9


if __name__ == "__main__":
    tvm.testing.main()