From 6c5d166a3698a3a2a59f05b547c22fb53a16a908 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 16 Aug 2023 18:50:14 -0400 Subject: [PATCH 01/18] Add control flow graph implementation --- include/tvm/relax/dataflow_analysis.h | 176 +++++++ .../tvm/relax/analysis/dataflow_analysis.py | 133 ++++++ src/relax/analysis/dataflow_analysis.cc | 195 ++++++++ tests/python/relax/test_dataflow_analysis.py | 430 ++++++++++++++++++ 4 files changed, 934 insertions(+) create mode 100644 include/tvm/relax/dataflow_analysis.h create mode 100644 python/tvm/relax/analysis/dataflow_analysis.py create mode 100644 src/relax/analysis/dataflow_analysis.cc create mode 100644 tests/python/relax/test_dataflow_analysis.py diff --git a/include/tvm/relax/dataflow_analysis.h b/include/tvm/relax/dataflow_analysis.h new file mode 100644 index 000000000000..dd3001a1ae55 --- /dev/null +++ b/include/tvm/relax/dataflow_analysis.h @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.h + * \brief A reusable framework for dataflow analysis in Relax. + * Based on Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * Do not confuse with dataflow pattern matching (does not use this machinery) + */ + +#ifndef TVM_RELAX_DATAFLOW_ANALYSIS_H_ +#define TVM_RELAX_DATAFLOW_ANALYSIS_H_ + +#include +#include +#include + +namespace tvm { +namespace relax { + +/*! \brief For dataflow analysis, we need to have a graph of basic blocks + * (i.e., a control flow graph). + * The trouble is that Relax's BindingBlocks are not necessarily basic blocks: + * A BindingBlock followed by a DataflowBlock followed by a BindingBlock + * is potentially a single basic blocks, whereas a single BindingBlock that + * contains an If expression may actually comprise multiple basic blocks. + * This representation is a lightweight way of representing basic blocks on top + * of Relax's AST + */ +class BasicBlockNode : public Object { + public: + /*! \brief The SeqExpr the basic block resides in. + * (In normal form, basic blocks cannot span multiple SeqExprs). */ + SeqExpr seq; + + /*! \brief The arguments to the basic block. + * If the basic block is the first in the function, args is the function arguments. + * The basic blocks corresponding to If branches have no arguments. + * The basic block corresponding to the merge point after the If + * will have one argument (corresponding to the merge of the value returned; + * this will be the variable that the If expression is bound to). */ + Array args; + + /*! \brief The final expression evaluated in the basic block. + * If the basic block ends with an If expression, the ret is the If *condition*. + * Otherwise, it will be the value returned by the SeqExpr + * (all other basic blocks will end where the SeqExpr ends).*/ + Expr ret; + + /*! \brief Index of the BindingBlock in the SeqExpr where the basic block starts + * (Convention: If the start_block_idx is past the final index of the SeqExpr, + * that means the basic block contains no bindings.) */ + size_t start_block_idx; + + /*! \brief Index of the binding in the BindingBlock where the basic block starts + * (convention: If the basic block is a merge point, use the index of the binding + * after the If node. Also, if the start_binding_idx is past the final index + * of the block, that means the basic block contains no bindings) */ + size_t start_binding_idx; + + /*! \brief Index of the BindingBlock in the SeqExpr where the basic block ends. + * (convention: If the basic block goes until the end of the SeqExpr, + * end_block_idx will be one _past_ the last index, i.e., seq->blocks.size()) */ + size_t end_block_idx; + + /*! \brief Index of the binding in the BindingBlock where the basic block ends + * (convention: If the end of the basic block is the end of the SeqExpr, + * end_binding_idx will be one _past_ the last idex, i.e., block->bindings.size()) */ + size_t end_binding_idx; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("seq", &seq); + v->Visit("args", &args); + v->Visit("ret", &ret); + v->Visit("start_block_idx", &start_block_idx); + v->Visit("start_binding_idx", &start_binding_idx); + v->Visit("end_block_idx", &end_block_idx); + v->Visit("end_binding_idx", &end_binding_idx); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.BasicBlock"; + TVM_DECLARE_BASE_OBJECT_INFO(BasicBlockNode, Object); +}; + +/* Representation of a basic block on top of Relax's AST. + */ +class BasicBlock : public ObjectRef { + public: + /*! + * \brief Create a BasicBlock. See the docs on BasicBlockNode for further details. + * + * \param seq: The SeqExpr in which the basic block resides. + * \param args: The arguments to the basic block. + * \param ret: The final expression in the basic block. + * \param start_block_idx: The index of the BindingBlock in the SeqExpr + * where the basic block starts. + * \param start_binding_idx: The index of the binding in the BindingBlock where the + * basic block starts. + * \param end_block_idx: The index of the BindingBlock in the SeqExpr + * where the basic block ends. + * \param end_binding_idx: The index of the binding in the BindingBlock where the + * basic block ends. + */ + TVM_DLL static BasicBlock Create(const SeqExpr& seq, const Array& args, const Expr& ret, + size_t start_block_idx, size_t start_binding_idx, + size_t end_block_idx, size_t end_binding_idx); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BasicBlock, ObjectRef, BasicBlockNode); +}; + +/* A control flow graph corresponding to a function. + */ +class ControlFlowGraphNode : public Object { + public: + /*! \brief The basic blocks in the graph. 0 is the entry point. */ + Array blocks; + /*! \brief The ith member is the list of predecessors (indices) to block i in blocks. */ + Array> preds; + /*! \brief The ith member is the list of successors (indices) to block i in blocks. */ + Array> succs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("blocks", &blocks); + v->Visit("preds", &preds); + v->Visit("succs", &succs); + } + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.analysis.ControlFlowGraph"; + TVM_DECLARE_BASE_OBJECT_INFO(ControlFlowGraphNode, Object); +}; + +class ControlFlowGraph : public ObjectRef { + public: + /*! + * \brief Create a ControlFlowGraph. + * + * \param blocks: The basic blocks corresponding to the graph nodes + * \param preds: List of lists of predecessors to each basic block. + * \param succs: List of lists of successors to each basic block. + */ + TVM_DLL static ControlFlowGraph Create(const Array& blocks, + const Array>& preds, + const Array>& succs); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ControlFlowGraph, ObjectRef, ControlFlowGraphNode); +}; + +/*! + * \brief Extracts the control flow graph for a Relax function. + * \param func The function. This conversion expects it to be normalized. + * \return The control flow graph corresponding to the function. + */ +ControlFlowGraph ExtractCFG(const Function& func); + +} // namespace relax +} // namespace tvm +#endif \ No newline at end of file diff --git a/python/tvm/relax/analysis/dataflow_analysis.py b/python/tvm/relax/analysis/dataflow_analysis.py new file mode 100644 index 000000000000..e88dff908d8f --- /dev/null +++ b/python/tvm/relax/analysis/dataflow_analysis.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Python bindings for the dataflow analysis framework +""" + +from typing import List +import tvm +from tvm.ir.base import Node +from tvm.relax.expr import Expr, SeqExpr, Function, Var +from . import _ffi_api + + +@tvm._ffi.register_object("relax.analysis.BasicBlock") +class BasicBlock(Node): + """Representation of a basic block on top of Relax's AST (SeqExprs)""" + + seq: SeqExpr + args: List[Var] + ret: Expr + start_block_idx: int + start_binding_idx: int + end_block_idx: int + end_binding_idx: int + + def __init__( + self, + seq: SeqExpr, + args: List[Var], + ret: Expr, + start_block_idx: int, + start_binding_idx: int, + end_block_idx: int, + end_binding_idx: int, + ): + """ + Create a basic block + + Parameters + ---------- + seq: SeqExpr + The SeqExpr that contains the basic block + (in normal form, no basic block can span across SeqExprs) + + args: List[Var] + The values passed into the block. + The starting block of a function takes in the function args. + Merge blocks (those after an If branch) take the variable + the If expression is bound to. + + ret: Expr + The expression corresponding to the final value produced by a block. + For blocks ending in a branch, the final value is the branch condition. + Otherwise, it is the `body` field of the SeqExpr. + + start_block_idx: int + The index of the block in the SeqExpr's block list where the basic block starts + + start_binding_idx: int + The index of the binding in the starting binding block where the basic block + starts (convention: if the basic block is a merge point, + use the index of the binding after the If node). + """ + return self.__init_handle_by_constructor__( + _ffi_api.BasicBlock, + seq, + args, + ret, + start_block_idx, + start_binding_idx, + end_block_idx, + end_binding_idx, + ) # type: ignore + + +@tvm._ffi.register_object("relax.analysis.ControlFlowGraph") +class ControlFlowGraph(Node): + """Representation of a control flow graph, marking the successors + and predecessors to all basic blocks""" + + def __init__(self, blocks: List[BasicBlock], preds: List[List[int]], succs: List[List[int]]): + """ + Instantiate a control flow graph + + Parameters + ---------- + blocks: List[BasicBlock] + List of basic blocks in the graph + + preds: List[List[int]] + The ith member is the list of predecessors to blocks[i] (given as indices in blocks) + + succs: List[List[int]] + The ith member is the list of successors to blocks[i] (given as indices in blocks) + """ + if len(blocks) != len(preds) or len(blocks) != len(succs): + raise ValueError("The lengths of blocks, preds, and succs must all match.") + + return self.__init_handle_by_constructor__( + _ffi_api.ControlFlowGraph, blocks, preds, succs + ) # type: ignore + + +def ExtractCFG(func: Function) -> ControlFlowGraph: + """ + Given a Relax function, produces the corresponding control flow graph. + The function is expected to have been normalized. + + Parameters + ---------- + func: Function + A Relax function. Must be in normal form. + + Returns + ------- + graph: ControlFlowGraph + Control flow graph corresponding to the function. + """ + return _ffi_api.ExtractCFG(func) # type: ignore diff --git a/src/relax/analysis/dataflow_analysis.cc b/src/relax/analysis/dataflow_analysis.cc new file mode 100644 index 000000000000..5bd8055e8d22 --- /dev/null +++ b/src/relax/analysis/dataflow_analysis.cc @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/dataflow_analysis.cc + * \brief Implementation of functionality in dataflow_analysis.h + */ +#include +#include + +namespace tvm { +namespace relax { + +TVM_REGISTER_NODE_TYPE(BasicBlockNode); + +BasicBlock BasicBlock::Create(const SeqExpr& seq, const Array& args, const Expr& ret, + size_t start_block_idx, size_t start_binding_idx, + size_t end_block_idx, size_t end_binding_idx) { + ObjectPtr n = make_object(); + n->seq = seq; + n->args = args; + n->ret = ret; + n->start_block_idx = start_block_idx; + n->start_binding_idx = start_binding_idx; + n->end_block_idx = end_block_idx; + n->end_binding_idx = end_binding_idx; + return BasicBlock(n); +} + +TVM_REGISTER_NODE_TYPE(ControlFlowGraphNode); + +ControlFlowGraph ControlFlowGraph::Create(const Array& blocks, + const Array>& preds, + const Array>& succs) { + ObjectPtr n = make_object(); + n->blocks = blocks; + n->preds = preds; + n->succs = succs; + return ControlFlowGraph(n); +} + +// Extracts a basic block and updates the running lists blocks, preds, and succs. +// The return value is the index of the final basic block processed in the seq expression +// (useful for processing branches). +size_t ExtractCFGHelper(const SeqExpr& seq, const Array& args, size_t start_block_idx, + size_t start_binding_idx, std::vector current_preds, + std::vector* blocks, std::vector>* preds, + std::vector>* succs) { + size_t end_block_idx = 0; + size_t end_binding_idx = 0; + Expr ret; + Optional branch_var; + Optional branch_expr; + + // go from the start index and continue until we hit the end of the block or a split point + bool hit_branch = false; + // note: if start_block_idx is past seq->blocks.size(), then the loop will not actually run + // and we will not hit a branch, so we will produce a basic block comprised only of the + // seq expr end expression + for (size_t i = start_block_idx; i < seq->blocks.size(); i++) { + for (size_t j = start_binding_idx; j < seq->blocks[i]->bindings.size(); j++) { + Binding binding = seq->blocks[i]->bindings[j]; + if (auto* var_binding = binding.as()) { + if (var_binding->value.as()) { + end_block_idx = i; + end_binding_idx = j; + branch_var = var_binding->var; + branch_expr = Downcast(var_binding->value); + ret = branch_expr.value()->cond; + hit_branch = true; + break; + } + } else if (auto* match_binding = binding.as()) { + if (match_binding->value.as()) { + end_block_idx = i; + end_binding_idx = j; + branch_var = var_binding->var; + branch_expr = Downcast(var_binding->value); + ret = branch_expr.value()->cond; + hit_branch = true; + break; + } + } else { + CHECK(false); // will never happen + } + } + if (hit_branch) { + break; + } + } + + if (!hit_branch) { + end_block_idx = seq->blocks.size(); + end_binding_idx = 0U; // doesn't matter which we use + ret = seq->body; + } + BasicBlock block = BasicBlock::Create(seq, args, ret, start_block_idx, start_binding_idx, + end_block_idx, end_binding_idx); + blocks->push_back(block); + size_t block_idx = blocks->size() - 1U; + succs->push_back({}); + preds->push_back(current_preds); + for (size_t pred : current_preds) { + succs->at(pred).push_back(block_idx); + } + // no branches: then we're done + if (!hit_branch) { + return block_idx; + } + // hit a branch: recurse down the branches and then set up the merge block + SeqExpr true_branch = Downcast(branch_expr.value()->true_branch); + SeqExpr false_branch = Downcast(branch_expr.value()->false_branch); + // the branches could contain their own branches, which is why we return the final block index + size_t end_true = ExtractCFGHelper(true_branch, {}, 0U, 0U, {block_idx}, blocks, preds, succs); + size_t end_false = ExtractCFGHelper(false_branch, {}, 0U, 0U, {block_idx}, blocks, preds, succs); + + // work out the start indices for the merge point + size_t next_start_block_idx = end_block_idx; + size_t next_start_binding_idx = end_binding_idx; + // figure out the next indices + if (end_binding_idx == seq->blocks[end_block_idx]->bindings.size() - 1) { + if (end_block_idx == seq->blocks.size() - 1) { + next_start_block_idx = seq->blocks.size(); + next_start_binding_idx = 0U; + } else { + next_start_block_idx = end_block_idx + 1; + next_start_binding_idx = 0U; + } + } else { + next_start_binding_idx = end_binding_idx + 1; + } + return ExtractCFGHelper(seq, {branch_var.value()}, next_start_block_idx, next_start_binding_idx, + {end_true, end_false}, blocks, preds, succs); +} + +ControlFlowGraph ExtractCFG(const Function& func) { + std::vector blocks; + std::vector> preds; + std::vector> succs; + ExtractCFGHelper(Downcast(func->body), func->params, 0U, 0U, {}, &blocks, &preds, + &succs); + + Array> pred_arr; + for (auto pred_vec : preds) { + Array pred_ints; + for (auto idx : pred_vec) { + pred_ints.push_back(Integer(idx)); + } + pred_arr.push_back(pred_ints); + } + Array> succ_arr; + for (auto succ_vec : succs) { + Array succ_ints; + for (auto idx : succ_vec) { + succ_ints.push_back(Integer(idx)); + } + succ_arr.push_back(succ_ints); + } + return ControlFlowGraph::Create(Array(blocks), pred_arr, succ_arr); +} + +TVM_REGISTER_GLOBAL("relax.analysis.BasicBlock") + .set_body_typed([](const SeqExpr& seq, const Array& args, const Expr& ret, + size_t start_block_idx, size_t start_binding_idx, size_t end_block_idx, + size_t end_binding_idx) { + return BasicBlock::Create(seq, args, ret, start_block_idx, start_binding_idx, end_block_idx, + end_binding_idx); + }); + +TVM_REGISTER_GLOBAL("relax.analysis.ControlFlowGraph") + .set_body_typed([](const Array& blocks, const Array>& preds, + const Array>& succs) { + return ControlFlowGraph::Create(blocks, preds, succs); + }); + +TVM_REGISTER_GLOBAL("relax.analysis.ExtractCFG").set_body_typed(ExtractCFG); + +} // namespace relax +} // namespace tvm \ No newline at end of file diff --git a/tests/python/relax/test_dataflow_analysis.py b/tests/python/relax/test_dataflow_analysis.py new file mode 100644 index 000000000000..5e9879bc3cd9 --- /dev/null +++ b/tests/python/relax/test_dataflow_analysis.py @@ -0,0 +1,430 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional +import tvm +from tvm import relax +from tvm.relax.analysis.dataflow_analysis import ControlFlowGraph, BasicBlock, ExtractCFG +from tvm.script import ir as I, relax as R +import tvm.testing + + +def assert_pred_succ_lists(graph: ControlFlowGraph, expected_preds: List[List[int]]) -> None: + assert tuple([tuple(preds) for preds in graph.preds]) == tuple( + [tuple(exp_preds) for exp_preds in expected_preds] + ) + + expected_succs = [[] for preds in expected_preds] + # we can automatically invert the predecessor list + # this also guarantees consistency + for i, pred_list in enumerate(expected_preds): + for pred in pred_list: + expected_succs[pred].append(i) + + assert tuple([tuple(succs) for succs in graph.succs]) == tuple( + [tuple(exp_succs) for exp_succs in expected_succs] + ) + + +# common pattern in normalization that we can check for: +# if condition: +# ... +# z = value1 +# else: +# ... +# z = value2 +# +# results in: +# +# VarBinding( +# z, +# If( +# condition, +# SeqExpr([..., BindingBlock([..., VarBinding(new_var1, value1)])], body=new_var1), +# SeqExpr([..., BindingBlock([..., VarBinding(new_var2, value2)])], body=new_var2) +# ) +# ) +# This function can be used for checking the SeqExprs inside the branches +def assert_ret_is_final_binding_in_seq(block: BasicBlock, check_op: Optional[str] = None): + seq_body = block.seq.body + final_binding = block.seq.blocks[-1].bindings[-1] + assert seq_body == final_binding.var + assert block.ret == seq_body + if check_op is not None: + assert isinstance(final_binding.value, relax.Call) + assert final_binding.value.op.name == check_op + + +# ensure that the exprs in each list match each other and that they do not match those in the other lists +def assert_distinct(*groups: List[relax.Expr]): + for i, group in enumerate(groups): + if len(group) == 0: + continue + for item in group[1:]: + assert item == group[0] + for other_group in groups[i + 1 :]: + for item in other_group: + assert group[0] != item + + +def test_trivial_CFG(): + @I.ir_module + class TrivialFunc: + @R.function + def main() -> R.Tensor((), "int32"): + return R.const(1, dtype="int32") + + graph = ExtractCFG(TrivialFunc["main"]) + assert len(graph.blocks) == 1 + assert_pred_succ_lists(graph, [[]]) + assert graph.blocks[0].ret == TrivialFunc["main"].body.body + assert graph.blocks[0].start_block_idx == 0 + assert graph.blocks[0].start_binding_idx == 0 + assert graph.blocks[0].end_block_idx == 0 + + +def test_sequence_of_bindings(): + @I.ir_module + class FuncWithBindings: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.add(x, x) + z = R.add(y, x) + q = R.multiply(z, x) + return q + + graph = ExtractCFG(FuncWithBindings["main"]) + assert len(graph.blocks) == 1 + assert_pred_succ_lists(graph, [[]]) + assert graph.blocks[0].ret == FuncWithBindings["main"].body.body + assert graph.blocks[0].args == FuncWithBindings["main"].params + assert graph.blocks[0].start_block_idx == 0 + assert graph.blocks[0].start_binding_idx == 0 + assert graph.blocks[0].end_block_idx == 1 + + +def test_dataflow_block(): + @I.ir_module + class FuncWithDataflow: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.add(x, x) + z = R.add(y, y) + with R.dataflow(): + q = R.multiply(z, z) + r = R.add(q, q) + R.output(r) + s = R.add(r, r) + with R.dataflow(): + t = R.multiply(s, s) + u = R.add(t, t) + R.output(u) + return u + + graph = ExtractCFG(FuncWithDataflow["main"]) + assert len(graph.blocks) == 1 + assert_pred_succ_lists(graph, [[]]) + assert graph.blocks[0].ret == FuncWithDataflow["main"].body.body + assert graph.blocks[0].args == FuncWithDataflow["main"].params + assert graph.blocks[0].start_block_idx == 0 + assert graph.blocks[0].start_binding_idx == 0 + # there are four binding blocks but they form one basic block + assert graph.blocks[0].end_block_idx == 4 + + +def test_simple_branch(): + @I.ir_module + class SimpleBranch: + @R.function + def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): + if cond: + x = R.const(1, dtype="int32") + y = R.add(x, x) + z = R.multiply(y, y) + else: + x = R.const(2, dtype="int32") + y = R.add(x, x) + z = R.multiply(y, y) + return z + + # basic blocks: + # 1. the starting block (no bindings) whose return is the branch condition + # 2. the true branch body (return: R.multiply(y, y)) + # 3. the false branch body (return: R.multiply(y, y)) + # 4. the merge block (no bindings, argument is z) whose return is z + graph = ExtractCFG(SimpleBranch["main"]) + assert len(graph.blocks) == 4 + assert_pred_succ_lists(graph, [[], [0], [0], [1, 2]]) + + assert graph.blocks[0].args == SimpleBranch["main"].params + assert graph.blocks[0].ret == SimpleBranch["main"].params[0] + assert graph.blocks[0].start_block_idx == 0 + assert graph.blocks[0].start_binding_idx == 0 + assert graph.blocks[0].end_block_idx == 0 + assert graph.blocks[0].end_binding_idx == 0 + + assert len(graph.blocks[1].args) == 0 + assert_ret_is_final_binding_in_seq(graph.blocks[1], "relax.multiply") + assert graph.blocks[1].start_block_idx == 0 + assert graph.blocks[1].start_binding_idx == 0 + assert graph.blocks[1].end_block_idx == 1 + + assert len(graph.blocks[2].args) == 0 + assert_ret_is_final_binding_in_seq(graph.blocks[2], "relax.multiply") + assert graph.blocks[2].start_block_idx == 0 + assert graph.blocks[2].start_binding_idx == 0 + assert graph.blocks[2].end_block_idx == 1 + + assert len(graph.blocks[3].args) == 1 + assert graph.blocks[3].args[0].name_hint == "z" + assert graph.blocks[3].ret == SimpleBranch["main"].body.body + # the if was the last binding in the block, so we're past the end + assert graph.blocks[3].start_block_idx == 1 + assert graph.blocks[3].end_block_idx == 1 + + assert_distinct( + [graph.blocks[0].seq, graph.blocks[3].seq], [graph.blocks[1]], [graph.blocks[2]] + ) + + +def test_bindings_after_branch(): + @I.ir_module + class BranchAndBind: + @R.function + def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): + x = R.const(1, dtype="int32") + y = R.add(x, x) + if cond: + z = R.multiply(y, y) + else: + z = R.add(y, y) + q = R.add(z, z) + return q + + graph = ExtractCFG(BranchAndBind["main"]) + assert len(graph.blocks) == 4 + assert_pred_succ_lists(graph, [[], [0], [0], [1, 2]]) + + # same as above example, except there are bindings preceding the if (included in block 0) + # and after the if (included in block 3) + + assert graph.blocks[0].args == BranchAndBind["main"].params + assert graph.blocks[0].ret == BranchAndBind["main"].params[0] + assert graph.blocks[0].start_block_idx == 0 + assert graph.blocks[0].start_binding_idx == 0 + assert graph.blocks[0].end_block_idx == 0 + assert graph.blocks[0].end_binding_idx == 2 + + assert len(graph.blocks[1].args) == 0 + assert_ret_is_final_binding_in_seq(graph.blocks[1], "relax.multiply") + assert graph.blocks[1].start_block_idx == 0 + assert graph.blocks[1].start_binding_idx == 0 + assert graph.blocks[1].end_block_idx == 1 + + assert len(graph.blocks[2].args) == 0 + assert_ret_is_final_binding_in_seq(graph.blocks[2], "relax.add") + assert graph.blocks[2].start_block_idx == 0 + assert graph.blocks[2].start_binding_idx == 0 + assert graph.blocks[2].end_block_idx == 1 + + assert len(graph.blocks[3].args) == 1 + assert graph.blocks[3].args[0].name_hint == "z" + assert graph.blocks[3].ret.name_hint == "q" + assert graph.blocks[3].start_block_idx == 0 + assert graph.blocks[3].start_binding_idx == 3 + assert graph.blocks[3].end_block_idx == 1 + assert graph.blocks[3].end_binding_idx == 0 + + assert_distinct( + [graph.blocks[0].seq, graph.blocks[3].seq], [graph.blocks[1]], [graph.blocks[2]] + ) + + +def test_branch_with_multiple_blocks(): + @I.ir_module + class LongBranches: + @R.function + def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): + if cond: + x = R.const(1, dtype="int32") + y = R.add(x, x) + with R.dataflow(): + z = R.multiply(y, y) + w = R.add(z, z) + v = R.multiply(w, w) + R.output(v) + q = R.add(v, v) + r = R.multiply(q, q) + else: + x = R.const(2, dtype="int32") + y = R.multiply(x, x) + with R.dataflow(): + z = R.add(y, y) + w = R.multiply(z, z) + v = R.add(w, w) + R.output(v) + q = R.multiply(v, v) + r = R.add(q, q) + return r + + graph = ExtractCFG(LongBranches["main"]) + # empty entry block, one block for each branch, and an empty exit block + assert len(graph.blocks) == 4 + assert_pred_succ_lists(graph, [[], [0], [0], [1, 2]]) + + assert graph.blocks[0].args == LongBranches["main"].params + assert graph.blocks[0].ret == LongBranches["main"].params[0] + assert graph.blocks[0].start_block_idx == 0 + assert graph.blocks[0].start_binding_idx == 0 + assert graph.blocks[0].end_block_idx == 0 + assert graph.blocks[0].end_binding_idx == 0 + + # there are 3 binding blocks included in each branch + assert len(graph.blocks[1].args) == 0 + assert_ret_is_final_binding_in_seq(graph.blocks[1], "relax.multiply") + assert graph.blocks[1].start_block_idx == 0 + assert graph.blocks[1].start_binding_idx == 0 + assert graph.blocks[1].end_block_idx == 3 + + assert len(graph.blocks[2].args) == 0 + assert_ret_is_final_binding_in_seq(graph.blocks[2], "relax.add") + assert graph.blocks[2].start_block_idx == 0 + assert graph.blocks[2].start_binding_idx == 0 + assert graph.blocks[2].end_block_idx == 3 + + assert len(graph.blocks[3].args) == 1 + assert graph.blocks[3].args[0].name_hint == "r" + assert graph.blocks[3].ret.name_hint == "r" + assert graph.blocks[3].start_block_idx == 1 + assert graph.blocks[3].end_block_idx == 1 + + assert_distinct( + [graph.blocks[0].seq, graph.blocks[3].seq], [graph.blocks[1]], [graph.blocks[2]] + ) + + +def test_nested_branches(): + @I.ir_module + class NestedBranches: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + cond1 = R.const(True, dtype="bool") + if cond1: + cond2 = R.const(False, dtype="bool") + if cond2: + y = R.add(x, x) + else: + y = R.multiply(x, x) + z = R.add(y, y) + else: + cond3 = R.const(True, dtype="bool") + if cond3: + y = R.multiply(x, x) + else: + y = R.add(x, x) + z = R.multiply(y, y) + return z + + graph = ExtractCFG(NestedBranches["main"]) + # basic blocks: entry block to func, entry block to true branch, true branch in true branch, + # false branch in true branch, merge block in true branch, + # entry to false branch, true branch in false branch, false branch in false branch, + # merge block in false branch, merge block in outer function + assert len(graph.blocks) == 10 + assert_pred_succ_lists( + graph, + [ + [], # function entry + [0], # true branch entry + [1], # true branch's true branch + [1], # true branch's false branch + [2, 3], # true branch's exit + [0], # false branch entry + [5], # false branch's true branch + [5], # false branch's false branch + [6, 7], # false branch exit + [4, 8], # function exit + ], + ) + + assert graph.blocks[0].args == NestedBranches["main"].params + assert graph.blocks[0].ret.name_hint == "cond1" + assert graph.blocks[0].start_block_idx == 0 + assert graph.blocks[0].start_binding_idx == 0 + assert graph.blocks[0].end_block_idx == 0 + assert graph.blocks[0].end_binding_idx == 1 + + assert len(graph.blocks[1].args) == 0 + assert graph.blocks[1].ret.name_hint == "cond2" + assert graph.blocks[1].start_block_idx == 0 + assert graph.blocks[1].start_binding_idx == 0 + assert graph.blocks[1].end_block_idx == 0 + assert graph.blocks[1].end_binding_idx == 1 + + assert len(graph.blocks[2].args) == 0 + assert_ret_is_final_binding_in_seq(graph.blocks[2], "relax.add") + assert graph.blocks[2].start_block_idx == 0 + assert graph.blocks[2].start_binding_idx == 0 + assert graph.blocks[2].end_block_idx == 1 + + assert len(graph.blocks[3].args) == 0 + assert_ret_is_final_binding_in_seq(graph.blocks[3], "relax.multiply") + assert graph.blocks[3].start_block_idx == 0 + assert graph.blocks[3].start_binding_idx == 0 + assert graph.blocks[3].end_block_idx == 1 + + assert len(graph.blocks[4].args) == 1 + assert graph.blocks[4].args[0].name_hint == "y" + assert_ret_is_final_binding_in_seq(graph.blocks[4], "relax.add") + assert graph.blocks[4].start_block_idx == 0 + assert graph.blocks[4].start_binding_idx == 2 + assert graph.blocks[4].end_block_idx == 1 + + assert len(graph.blocks[5].args) == 0 + assert graph.blocks[5].ret.name_hint == "cond3" + assert graph.blocks[5].start_block_idx == 0 + assert graph.blocks[5].start_binding_idx == 0 + assert graph.blocks[5].end_block_idx == 0 + assert graph.blocks[5].end_binding_idx == 1 + + assert len(graph.blocks[6].args) == 0 + assert_ret_is_final_binding_in_seq(graph.blocks[6], "relax.multiply") + assert graph.blocks[6].start_block_idx == 0 + assert graph.blocks[6].start_binding_idx == 0 + assert graph.blocks[6].end_block_idx == 1 + + assert len(graph.blocks[7].args) == 0 + assert_ret_is_final_binding_in_seq(graph.blocks[7], "relax.add") + assert graph.blocks[7].start_block_idx == 0 + assert graph.blocks[7].start_binding_idx == 0 + assert graph.blocks[7].end_block_idx == 1 + + assert len(graph.blocks[8].args) == 1 + assert graph.blocks[8].args[0].name_hint == "y" + assert_ret_is_final_binding_in_seq(graph.blocks[8], "relax.multiply") + assert graph.blocks[8].start_block_idx == 0 + assert graph.blocks[8].start_binding_idx == 2 + assert graph.blocks[8].end_block_idx == 1 + + assert len(graph.blocks[9].args) == 1 + assert graph.blocks[9].args[0].name_hint == "z" + assert graph.blocks[9].ret.name_hint == "z" + assert graph.blocks[9].start_block_idx == 1 + assert graph.blocks[9].end_block_idx == 1 + + +if __name__ == "__main__": + tvm.testing.main() From f4f5de11b1f4f0535cfb771a18774800ae30fc92 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 17 Aug 2023 17:09:26 -0400 Subject: [PATCH 02/18] Implement dataflow analysis framework and add tests --- include/tvm/relax/dataflow_analysis.h | 28 +++++ .../tvm/relax/analysis/dataflow_analysis.py | 50 +++++++- src/relax/analysis/dataflow_analysis.cc | 52 +++++++++ tests/python/relax/test_dataflow_analysis.py | 107 +++++++++++++++++- 4 files changed, 234 insertions(+), 3 deletions(-) diff --git a/include/tvm/relax/dataflow_analysis.h b/include/tvm/relax/dataflow_analysis.h index dd3001a1ae55..2b5481ad07d1 100644 --- a/include/tvm/relax/dataflow_analysis.h +++ b/include/tvm/relax/dataflow_analysis.h @@ -171,6 +171,34 @@ class ControlFlowGraph : public ObjectRef { */ ControlFlowGraph ExtractCFG(const Function& func); +/*! + * \brief Generic implementation of dataflow analysis, based on + * Adrian Sampson's course material: + * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + * + * The analysis creates input and output maps (mapping basic block indices to a domain), + * sets the initial input and output for each basic block to the init value, and then + * performs a traversal of the CFG (BFS in this implementation, since unlike the general case, + * we do not have loops) and uses the transfer and merge function to update the inputs and + * outputs. The analysis can proceed forwards (from block 0 onwards) or backwards (from the last + * block back), flipping the roles of the input and output maps in the cases. + * + * \param forward Whether to perform a forward or backward analysis + * \param cfg The input control flow graph + * \param init The value corresponding to an initial domain + * \param transfer_func Given an input domain and a basic block, determine the resulting domain + * \param merge_func Given a set of domains, combine them to form a single new domain + * (note: in Relax, a basic block can never have more than two predecessors/successors) + * + * \return Two arrays, the first being the "input map" (domain being passed *into* + * each basic block in the CFG) and the second being the "output map" (the domain + * being passed *out of* the corresponding basic block) + */ +std::pair, Array> DataflowAnalysis( + const ControlFlowGraph& cfg, const ObjectRef& init, + std::function transfer_func, + std::function merge_func, bool forward = true); + } // namespace relax } // namespace tvm #endif \ No newline at end of file diff --git a/python/tvm/relax/analysis/dataflow_analysis.py b/python/tvm/relax/analysis/dataflow_analysis.py index e88dff908d8f..b1e2ba66e592 100644 --- a/python/tvm/relax/analysis/dataflow_analysis.py +++ b/python/tvm/relax/analysis/dataflow_analysis.py @@ -18,7 +18,7 @@ Python bindings for the dataflow analysis framework """ -from typing import List +from typing import Any, Callable, List, Tuple import tvm from tvm.ir.base import Node from tvm.relax.expr import Expr, SeqExpr, Function, Var @@ -131,3 +131,51 @@ def ExtractCFG(func: Function) -> ControlFlowGraph: Control flow graph corresponding to the function. """ return _ffi_api.ExtractCFG(func) # type: ignore + + +def DataflowAnalysis( + cfg: ControlFlowGraph, + init: Any, + transfer_func: Callable[[BasicBlock, Any], Any], + merge_func: Callable[[Any, Any], Any], + forward: bool = True, +) -> Tuple[List[Any], List[Any]]: + """ + Generic dataflow analysis framework, based on Adrian Sampson's course notes: + https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ + + The analysis creates input and output maps (mapping basic block indices to a domain), + sets the initial input and output for each basic block to the init value, and then + performs a traversal of the CFG (BFS in this implementation, since unlike the general case, + we do not have loops) and uses the transfer and merge function to update the inputs and + outputs. The analysis can proceed forwards (from block 0 onwards) or backwards (from the last + block back), flipping the roles of the input and output maps in the cases. + + Parameters + ---------- + cfg: ControlFlowGraph + The input control flow graph + + init: Any + The initial value in the analysis domain to which all blocks should be initialized. + + transfer_func: Callable[[BasicBlock, Any], Any] + Given a basic block and the input domain, compute the new output domain. + + merge_func: Callable[[Any, Any], Any] + When two output domains are fed into a single block (i.e., after an If branch), + the merge function is used to combine them into a single domain. + + forward: bool + If true, the analysis proceeds forwards (starting from block 0 and going onwards). + If false, the analysis proceeds backwards (starting from the last block and going back). + The input and output maps play the opposite roles in forward and backward analyses. + I.e., in a backward analysis, the "final output" is the input map entry for block 0 + and the initial input is the output map entry for the last block. + + Returns + ------- + ret: Tuple[List[Any], List[Any]] + A pair of the final input and output maps + """ + return _ffi_api.DataflowAnalysis(forward, cfg, init, transfer_func, merge_func) # type: ignore diff --git a/src/relax/analysis/dataflow_analysis.cc b/src/relax/analysis/dataflow_analysis.cc index 5bd8055e8d22..ab5c84f8011c 100644 --- a/src/relax/analysis/dataflow_analysis.cc +++ b/src/relax/analysis/dataflow_analysis.cc @@ -24,6 +24,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -175,6 +177,49 @@ ControlFlowGraph ExtractCFG(const Function& func) { return ControlFlowGraph::Create(Array(blocks), pred_arr, succ_arr); } +std::pair, Array> DataflowAnalysis( + const ControlFlowGraph& cfg, const ObjectRef& init, + std::function transfer_func, + std::function merge_func, bool forward) { + std::vector in_map; + std::vector out_map; + for (size_t i = 0; i < cfg->blocks.size(); i++) { + in_map.push_back(init); + out_map.push_back(init); + } + + // Modification from Adrian Sampson's version: + // Since there are no loops in our AST, one traversal through the CFG suffices. + // We will do BFS + std::queue worklist; + worklist.push((forward) ? 0 : cfg->blocks.size() - 1); + while (!worklist.empty()) { + size_t idx = worklist.front(); + worklist.pop(); + Array prev = (forward) ? cfg->preds[idx] : cfg->succs[idx]; + Array next = (forward) ? cfg->succs[idx] : cfg->preds[idx]; + std::vector* results = (forward) ? &out_map : &in_map; + std::vector* inputs = (forward) ? &in_map : &out_map; + + // Cases (for forward analysis): + // 0 predecessors: The first block in the function + // 1 predecessor: A branch in an If node (no merge needed) + // 2 predecessors: The merge block after an If node (merge needed) + // (Analogous for successors in backward analysis) + inputs->operator[](idx) = (prev.size() == 0) ? init + : (prev.size() == 1) ? results->at(prev[0].IntValue()) + : merge_func(results->at(prev[0].IntValue()), + results->at(prev[1].IntValue())); + results->operator[](idx) = transfer_func(cfg->blocks[idx], inputs->at(idx)); + + for (Integer next_idx : next) { + worklist.push(next_idx.IntValue()); + } + } + + return {Array(in_map), Array(out_map)}; +} + TVM_REGISTER_GLOBAL("relax.analysis.BasicBlock") .set_body_typed([](const SeqExpr& seq, const Array& args, const Expr& ret, size_t start_block_idx, size_t start_binding_idx, size_t end_block_idx, @@ -191,5 +236,12 @@ TVM_REGISTER_GLOBAL("relax.analysis.ControlFlowGraph") TVM_REGISTER_GLOBAL("relax.analysis.ExtractCFG").set_body_typed(ExtractCFG); +TVM_REGISTER_GLOBAL("relax.analysis.DataflowAnalysis") + .set_body_typed([](const ControlFlowGraph& cfg, const ObjectRef& init, PackedFunc transfer_func, + PackedFunc merge_func, bool forward) { + auto ret = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward); + return Array({ret.first, ret.second}); + }); + } // namespace relax } // namespace tvm \ No newline at end of file diff --git a/tests/python/relax/test_dataflow_analysis.py b/tests/python/relax/test_dataflow_analysis.py index 5e9879bc3cd9..5013a94674b4 100644 --- a/tests/python/relax/test_dataflow_analysis.py +++ b/tests/python/relax/test_dataflow_analysis.py @@ -14,10 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List, Optional +from typing import Any, Callable, List, Optional import tvm from tvm import relax -from tvm.relax.analysis.dataflow_analysis import ControlFlowGraph, BasicBlock, ExtractCFG +from tvm.relax.analysis.dataflow_analysis import ( + ControlFlowGraph, + BasicBlock, + ExtractCFG, + DataflowAnalysis, +) from tvm.script import ir as I, relax as R import tvm.testing @@ -426,5 +431,103 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): assert graph.blocks[9].end_block_idx == 1 +def test_simple_analysis(): + @I.ir_module + class TrivialFunc: + @R.function + def main() -> R.Tensor((), "int32"): + return R.const(1, dtype="int32") + + # only one basic block to consider here + init = {"a": 1} + + def transfer_func(_, domain): + # the input domain will be converted into an immutable TVM Map, + # so we have to create a new domain + new_domain = {} + for k, v in domain.items(): + new_domain[k] = v + new_domain["b"] = 2 + return new_domain + + # there will not be a merge here + merge_func = lambda domain1, _: domain1 + + def check_expected_maps(in_map, out_map): + # we expect the in map to be the init value and the out map to have the key b + assert len(in_map[0]) == 1 + assert in_map[0]["a"] == 1 + assert len(out_map[0]) == 2 + assert out_map[0]["a"] == 1 + assert out_map[0]["b"] == 2 + + cfg = ExtractCFG(TrivialFunc["main"]) + in_map, out_map = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward=True) + check_expected_maps(in_map, out_map) + # backward will just flip in and out + in_map, out_map = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward=False) + check_expected_maps(out_map, in_map) + + +def test_simple_analysis_with_merge(): + @I.ir_module + class SimpleBranch: + @R.function + def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): + if cond: + x = R.const(1, dtype="int32") + y = R.add(x, x) + z = R.multiply(y, y) + else: + x = R.const(2, dtype="int32") + y = R.add(x, x) + z = R.multiply(y, y) + return z + + init = {"a": 1} + + def transfer_func(_, domain): + new_domain = {} + for k, v in domain.items(): + new_domain[k] = v + 1 + return new_domain + + def merge_func(domain1, domain2): + new_domain = {} + for k, v in domain1.items(): + new_domain[k] = v + for k, v in domain2.items(): + if k not in new_domain or (k in new_domain and new_domain[k] < v): + new_domain[k] = v + if "merge" not in new_domain: + new_domain["merge"] = 1 + return new_domain + + def check_expected_maps(in_map, out_map, forward=True): + # merge will happen in the last block only + i = 0 if forward else 3 + assert len(in_map[i]) == 1 and len(out_map[i]) == 1 + assert in_map[i]["a"] == 1 + assert out_map[i]["a"] == 2 + + for j in (1, 2): + assert len(in_map[j]) == 1 and len(out_map[j]) == 1 + assert in_map[j]["a"] == 2 + assert out_map[j]["a"] == 3 + + i = 3 if forward else 0 + assert len(in_map[i]) == 2 and len(out_map[i]) == 2 + assert in_map[i]["a"] == 3 + assert in_map[i]["merge"] == 1 + assert out_map[i]["a"] == 4 + assert out_map[i]["merge"] == 2 + + cfg = ExtractCFG(SimpleBranch["main"]) + in_map, out_map = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward=True) + check_expected_maps(in_map, out_map, forward=True) + in_map, out_map = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward=False) + check_expected_maps(out_map, in_map, forward=False) + + if __name__ == "__main__": tvm.testing.main() From a3b6657e6af9065c464883584f1d9c3e2693939b Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 17 Aug 2023 17:19:40 -0400 Subject: [PATCH 03/18] Correct doc comment --- src/relax/analysis/dataflow_analysis.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/analysis/dataflow_analysis.cc b/src/relax/analysis/dataflow_analysis.cc index ab5c84f8011c..6a8f6721e0b5 100644 --- a/src/relax/analysis/dataflow_analysis.cc +++ b/src/relax/analysis/dataflow_analysis.cc @@ -18,7 +18,7 @@ */ /*! - * \file tvm/relax/dataflow_analysis.cc + * \file tvm/relax/analysis/dataflow_analysis.cc * \brief Implementation of functionality in dataflow_analysis.h */ #include From 26b0d21885b978ec17fe9e466370d30397865337 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 31 Aug 2023 14:21:03 -0400 Subject: [PATCH 04/18] Phrase dataflow analysis per binding instead of per basic block --- include/tvm/relax/dataflow_analysis.h | 159 +++--- .../tvm/relax/analysis/dataflow_analysis.py | 120 ++--- src/relax/analysis/dataflow_analysis.cc | 198 ++++---- tests/python/relax/test_dataflow_analysis.py | 472 ++++++++---------- 4 files changed, 429 insertions(+), 520 deletions(-) diff --git a/include/tvm/relax/dataflow_analysis.h b/include/tvm/relax/dataflow_analysis.h index 2b5481ad07d1..f8782e9a89ad 100644 --- a/include/tvm/relax/dataflow_analysis.h +++ b/include/tvm/relax/dataflow_analysis.h @@ -35,110 +35,92 @@ namespace tvm { namespace relax { -/*! \brief For dataflow analysis, we need to have a graph of basic blocks - * (i.e., a control flow graph). - * The trouble is that Relax's BindingBlocks are not necessarily basic blocks: - * A BindingBlock followed by a DataflowBlock followed by a BindingBlock - * is potentially a single basic blocks, whereas a single BindingBlock that - * contains an If expression may actually comprise multiple basic blocks. - * This representation is a lightweight way of representing basic blocks on top - * of Relax's AST +/*! \brief For dataflow analysis, we need to have a control flow graph. + * We will organize this graphs by bindings, which allows analyses to + * state their results for each binding in a SeqExpr. + * + * There are a few cases that have to be handled: + * 1. A normal binding (most common)ICHECK + * 2. The condition expression in an If node (a "split" point) + * 3. A merge point (the variable to which an If node is bound: it is a "merge" between + * the SeqExprs in the true and false branches) + * 4. The body expression in a SeqExpr (not actually bound) */ -class BasicBlockNode : public Object { +enum BindingNodeKind : int { + kBinding = 0, + kIfCond = 1, + kIfMerge = 2, + kSeqBody = 3 +}; + +class GraphBindingNode : public Object { public: - /*! \brief The SeqExpr the basic block resides in. - * (In normal form, basic blocks cannot span multiple SeqExprs). */ + /*! \brief The SeqExpr the binding resides in. */ SeqExpr seq; - /*! \brief The arguments to the basic block. - * If the basic block is the first in the function, args is the function arguments. - * The basic blocks corresponding to If branches have no arguments. - * The basic block corresponding to the merge point after the If - * will have one argument (corresponding to the merge of the value returned; - * this will be the variable that the If expression is bound to). */ + /*! \brief The arguments to the binding. Only the first binding in the graph has arguments + * (i.e., the function arguments). */ Array args; - /*! \brief The final expression evaluated in the basic block. - * If the basic block ends with an If expression, the ret is the If *condition*. - * Otherwise, it will be the value returned by the SeqExpr - * (all other basic blocks will end where the SeqExpr ends).*/ - Expr ret; - - /*! \brief Index of the BindingBlock in the SeqExpr where the basic block starts - * (Convention: If the start_block_idx is past the final index of the SeqExpr, - * that means the basic block contains no bindings.) */ - size_t start_block_idx; - - /*! \brief Index of the binding in the BindingBlock where the basic block starts - * (convention: If the basic block is a merge point, use the index of the binding - * after the If node. Also, if the start_binding_idx is past the final index - * of the block, that means the basic block contains no bindings) */ - size_t start_binding_idx; - - /*! \brief Index of the BindingBlock in the SeqExpr where the basic block ends. - * (convention: If the basic block goes until the end of the SeqExpr, - * end_block_idx will be one _past_ the last index, i.e., seq->blocks.size()) */ - size_t end_block_idx; - - /*! \brief Index of the binding in the BindingBlock where the basic block ends - * (convention: If the end of the basic block is the end of the SeqExpr, - * end_binding_idx will be one _past_ the last idex, i.e., block->bindings.size()) */ - size_t end_binding_idx; + /*! \brief Index of the binding block in the SeqExpr where the binding is found. + * Convention: We put the SeqExpr body at one block past the final block. */ + size_t block_idx; + + /*! \brief Index of the binding within the binding block corresponding to this binding. + * Convention: Both the If condition and merge are mapped to the same index. + * We use the kind to distinguish. */ + size_t binding_idx; + + /*! \brief The kind of binding this is. */ + BindingNodeKind kind; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("seq", &seq); v->Visit("args", &args); - v->Visit("ret", &ret); - v->Visit("start_block_idx", &start_block_idx); - v->Visit("start_binding_idx", &start_binding_idx); - v->Visit("end_block_idx", &end_block_idx); - v->Visit("end_binding_idx", &end_binding_idx); + v->Visit("block_idx", &block_idx); + v->Visit("binding_idx", &binding_idx); + v->Visit("kind", &kind); } static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - static constexpr const char* _type_key = "relax.analysis.BasicBlock"; - TVM_DECLARE_BASE_OBJECT_INFO(BasicBlockNode, Object); + static constexpr const char* _type_key = "relax.analysis.GraphBinding"; + TVM_DECLARE_BASE_OBJECT_INFO(GraphBindingNode, Object); }; -/* Representation of a basic block on top of Relax's AST. - */ -class BasicBlock : public ObjectRef { +/*! \brief Representation of a binding in the control flow graph */ +class GraphBinding : public ObjectRef { public: /*! - * \brief Create a BasicBlock. See the docs on BasicBlockNode for further details. + * \brief Create a GraphBinding. See the docs on GraphBindingNode for further details. * - * \param seq: The SeqExpr in which the basic block resides. - * \param args: The arguments to the basic block. - * \param ret: The final expression in the basic block. - * \param start_block_idx: The index of the BindingBlock in the SeqExpr - * where the basic block starts. - * \param start_binding_idx: The index of the binding in the BindingBlock where the - * basic block starts. - * \param end_block_idx: The index of the BindingBlock in the SeqExpr - * where the basic block ends. - * \param end_binding_idx: The index of the binding in the BindingBlock where the - * basic block ends. + * \param seq: The SeqExpr in which the binding resides. + * \param args: The arguments to the binding (only nonempty for the first binding: + * these will be the function arguments) + * \param block_idx: The index of the BindingBlock in the SeqExpr + * where the binding resides (for the return expression, use one past the final block). + * \param binding_idx: The index of the binding in the BindingBlock corresponding to the binding. + * \param kind: The kind of binding this is. (Used especially to distinguish If node conditions + * from the merge after the If) */ - TVM_DLL static BasicBlock Create(const SeqExpr& seq, const Array& args, const Expr& ret, - size_t start_block_idx, size_t start_binding_idx, - size_t end_block_idx, size_t end_binding_idx); + TVM_DLL static GraphBinding Create(const SeqExpr& seq, const Array& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BasicBlock, ObjectRef, BasicBlockNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(GraphBinding, ObjectRef, GraphBindingNode); }; /* A control flow graph corresponding to a function. */ class ControlFlowGraphNode : public Object { public: - /*! \brief The basic blocks in the graph. 0 is the entry point. */ - Array blocks; - /*! \brief The ith member is the list of predecessors (indices) to block i in blocks. */ + /*! \brief The bindings in the graph. 0 is the entry point. */ + Array bindings; + /*! \brief The ith member is the list of predecessors (indices) to binding i in bindings. */ Array> preds; - /*! \brief The ith member is the list of successors (indices) to block i in blocks. */ + /*! \brief The ith member is the list of successors (indices) to binding i in bindings. */ Array> succs; void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("blocks", &blocks); + v->Visit("bindings", &bindings); v->Visit("preds", &preds); v->Visit("succs", &succs); } @@ -153,11 +135,11 @@ class ControlFlowGraph : public ObjectRef { /*! * \brief Create a ControlFlowGraph. * - * \param blocks: The basic blocks corresponding to the graph nodes - * \param preds: List of lists of predecessors to each basic block. - * \param succs: List of lists of successors to each basic block. + * \param bindings: The bindings in the graph + * \param preds: List of lists of predecessors to each binding. + * \param succs: List of lists of successors to each binding. */ - TVM_DLL static ControlFlowGraph Create(const Array& blocks, + TVM_DLL static ControlFlowGraph Create(const Array& bindings, const Array>& preds, const Array>& succs); @@ -173,30 +155,31 @@ ControlFlowGraph ExtractCFG(const Function& func); /*! * \brief Generic implementation of dataflow analysis, based on - * Adrian Sampson's course material: + * Adrian Sampson's course material, except binding by binding + * instead of basic block by basic block: * https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ * - * The analysis creates input and output maps (mapping basic block indices to a domain), - * sets the initial input and output for each basic block to the init value, and then + * The analysis creates input and output maps (mapping binding indices to a domain), + * sets the initial input and output for each binding to the init value, and then * performs a traversal of the CFG (BFS in this implementation, since unlike the general case, * we do not have loops) and uses the transfer and merge function to update the inputs and - * outputs. The analysis can proceed forwards (from block 0 onwards) or backwards (from the last - * block back), flipping the roles of the input and output maps in the cases. + * outputs. The analysis can proceed forwards (from binding 0 onwards) or backwards (from the + * last binding back), flipping the roles of the input and output maps in the cases. * * \param forward Whether to perform a forward or backward analysis * \param cfg The input control flow graph * \param init The value corresponding to an initial domain - * \param transfer_func Given an input domain and a basic block, determine the resulting domain + * \param transfer_func Given an input domain and a binding, determine the resulting domain * \param merge_func Given a set of domains, combine them to form a single new domain - * (note: in Relax, a basic block can never have more than two predecessors/successors) + * (note: in Relax, a binding can never have more than two predecessors/successors) * * \return Two arrays, the first being the "input map" (domain being passed *into* - * each basic block in the CFG) and the second being the "output map" (the domain - * being passed *out of* the corresponding basic block) + * each binding in the CFG) and the second being the "output map" (the domain + * being passed *out of* the corresponding binding) */ std::pair, Array> DataflowAnalysis( const ControlFlowGraph& cfg, const ObjectRef& init, - std::function transfer_func, + std::function transfer_func, std::function merge_func, bool forward = true); } // namespace relax diff --git a/python/tvm/relax/analysis/dataflow_analysis.py b/python/tvm/relax/analysis/dataflow_analysis.py index b1e2ba66e592..e942e60f855e 100644 --- a/python/tvm/relax/analysis/dataflow_analysis.py +++ b/python/tvm/relax/analysis/dataflow_analysis.py @@ -17,7 +17,7 @@ """ Python bindings for the dataflow analysis framework """ - +from enum import Enum from typing import Any, Callable, List, Tuple import tvm from tvm.ir.base import Node @@ -25,65 +25,62 @@ from . import _ffi_api -@tvm._ffi.register_object("relax.analysis.BasicBlock") -class BasicBlock(Node): - """Representation of a basic block on top of Relax's AST (SeqExprs)""" +class BindingNodeKind(Enum): + kBinding = 0 + kIfCond = 1 + kIfMerge = 2 + kSeqBody = 3 + + +@tvm._ffi.register_object("relax.analysis.GraphBinding") +class GraphBinding(Node): + """Representation of a binding in a control flow graph""" seq: SeqExpr args: List[Var] - ret: Expr - start_block_idx: int - start_binding_idx: int - end_block_idx: int - end_binding_idx: int + block_idx: int + binding_idx: int + kind: BindingNodeKind def __init__( self, seq: SeqExpr, args: List[Var], - ret: Expr, - start_block_idx: int, - start_binding_idx: int, - end_block_idx: int, - end_binding_idx: int, + block_idx: int, + binding_idx: int, + kind: BindingNodeKind, ): """ - Create a basic block + Create a graph binding Parameters ---------- seq: SeqExpr - The SeqExpr that contains the basic block - (in normal form, no basic block can span across SeqExprs) + The SeqExpr that contains the binding args: List[Var] - The values passed into the block. - The starting block of a function takes in the function args. - Merge blocks (those after an If branch) take the variable - the If expression is bound to. - - ret: Expr - The expression corresponding to the final value produced by a block. - For blocks ending in a branch, the final value is the branch condition. - Otherwise, it is the `body` field of the SeqExpr. - - start_block_idx: int - The index of the block in the SeqExpr's block list where the basic block starts - - start_binding_idx: int - The index of the binding in the starting binding block where the basic block - starts (convention: if the basic block is a merge point, - use the index of the binding after the If node). + Arguments taken by the binding (only used for the entry binding: + these will be the function arguments. Otherwise, this array should be empty.) + + block_idx: int + The index of the block in the SeqExpr's block list where the binding resides + (convention: for the SeqExpr body, we will use one past the final block) + + binding_idx: int + The index of the binding in the binding block corresponding to this binding. + + kind: BindingNodeKind + The kind of binding. We distinguish between ordinary bindings, + If conditions, If merges (the var bound to the result of the If node), + and the body of the SeqExpr. """ return self.__init_handle_by_constructor__( - _ffi_api.BasicBlock, + _ffi_api.GraphBinding, seq, args, - ret, - start_block_idx, - start_binding_idx, - end_block_idx, - end_binding_idx, + block_idx, + binding_idx, + kind, ) # type: ignore @@ -92,26 +89,28 @@ class ControlFlowGraph(Node): """Representation of a control flow graph, marking the successors and predecessors to all basic blocks""" - def __init__(self, blocks: List[BasicBlock], preds: List[List[int]], succs: List[List[int]]): + def __init__( + self, bindings: List[GraphBinding], preds: List[List[int]], succs: List[List[int]] + ): """ Instantiate a control flow graph Parameters ---------- - blocks: List[BasicBlock] - List of basic blocks in the graph + bindings: List[GraphBnding] + List of bindings in the graph preds: List[List[int]] - The ith member is the list of predecessors to blocks[i] (given as indices in blocks) + The ith member is the list of predecessors to bindings[i] (given as indices in bindings) succs: List[List[int]] - The ith member is the list of successors to blocks[i] (given as indices in blocks) + The ith member is the list of successors to bindings[i] (given as indices in bindings) """ - if len(blocks) != len(preds) or len(blocks) != len(succs): + if len(bindings) != len(preds) or len(bindings) != len(succs): raise ValueError("The lengths of blocks, preds, and succs must all match.") return self.__init_handle_by_constructor__( - _ffi_api.ControlFlowGraph, blocks, preds, succs + _ffi_api.ControlFlowGraph, bindings, preds, succs ) # type: ignore @@ -136,20 +135,21 @@ def ExtractCFG(func: Function) -> ControlFlowGraph: def DataflowAnalysis( cfg: ControlFlowGraph, init: Any, - transfer_func: Callable[[BasicBlock, Any], Any], + transfer_func: Callable[[GraphBinding, Any], Any], merge_func: Callable[[Any, Any], Any], forward: bool = True, ) -> Tuple[List[Any], List[Any]]: """ - Generic dataflow analysis framework, based on Adrian Sampson's course notes: + Generic dataflow analysis framework, based on Adrian Sampson's course notes, + except binding by binding instead of basic block by basic block: https://www.cs.cornell.edu/courses/cs6120/2020fa/lesson/4/ - The analysis creates input and output maps (mapping basic block indices to a domain), - sets the initial input and output for each basic block to the init value, and then + The analysis creates input and output maps (mapping binding indices to a domain), + sets the initial input and output for each binding to the init value, and then performs a traversal of the CFG (BFS in this implementation, since unlike the general case, we do not have loops) and uses the transfer and merge function to update the inputs and - outputs. The analysis can proceed forwards (from block 0 onwards) or backwards (from the last - block back), flipping the roles of the input and output maps in the cases. + outputs. The analysis can proceed forwards (from binding 0 onwards) or backwards (from the last + binding back), flipping the roles of the input and output maps in the cases. Parameters ---------- @@ -159,23 +159,23 @@ def DataflowAnalysis( init: Any The initial value in the analysis domain to which all blocks should be initialized. - transfer_func: Callable[[BasicBlock, Any], Any] - Given a basic block and the input domain, compute the new output domain. + transfer_func: Callable[[GraphBinding, Any], Any] + Given a binding and the input domain, compute the new output domain. merge_func: Callable[[Any, Any], Any] When two output domains are fed into a single block (i.e., after an If branch), the merge function is used to combine them into a single domain. forward: bool - If true, the analysis proceeds forwards (starting from block 0 and going onwards). - If false, the analysis proceeds backwards (starting from the last block and going back). + If true, the analysis proceeds forwards (starting from binding 0 and going onwards). + If false, the analysis proceeds backwards (starting from the last binding and going back). The input and output maps play the opposite roles in forward and backward analyses. - I.e., in a backward analysis, the "final output" is the input map entry for block 0 - and the initial input is the output map entry for the last block. + I.e., in a backward analysis, the "final output" is the input map entry for binding 0 + and the initial input is the output map entry for the last binding. Returns ------- ret: Tuple[List[Any], List[Any]] A pair of the final input and output maps """ - return _ffi_api.DataflowAnalysis(forward, cfg, init, transfer_func, merge_func) # type: ignore + return _ffi_api.DataflowAnalysis(cfg, init, transfer_func, merge_func, forward) # type: ignore diff --git a/src/relax/analysis/dataflow_analysis.cc b/src/relax/analysis/dataflow_analysis.cc index 6a8f6721e0b5..8a30beb117e2 100644 --- a/src/relax/analysis/dataflow_analysis.cc +++ b/src/relax/analysis/dataflow_analysis.cc @@ -29,133 +29,108 @@ namespace tvm { namespace relax { -TVM_REGISTER_NODE_TYPE(BasicBlockNode); +TVM_REGISTER_NODE_TYPE(GraphBindingNode); -BasicBlock BasicBlock::Create(const SeqExpr& seq, const Array& args, const Expr& ret, - size_t start_block_idx, size_t start_binding_idx, - size_t end_block_idx, size_t end_binding_idx) { - ObjectPtr n = make_object(); +GraphBinding GraphBinding::Create(const SeqExpr& seq, const Array& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind) { + ObjectPtr n = make_object(); n->seq = seq; n->args = args; - n->ret = ret; - n->start_block_idx = start_block_idx; - n->start_binding_idx = start_binding_idx; - n->end_block_idx = end_block_idx; - n->end_binding_idx = end_binding_idx; - return BasicBlock(n); + n->block_idx = block_idx; + n->binding_idx = binding_idx; + n->kind = kind; + return GraphBinding(n); } TVM_REGISTER_NODE_TYPE(ControlFlowGraphNode); -ControlFlowGraph ControlFlowGraph::Create(const Array& blocks, +ControlFlowGraph ControlFlowGraph::Create(const Array& bindings, const Array>& preds, const Array>& succs) { ObjectPtr n = make_object(); - n->blocks = blocks; + n->bindings = bindings; n->preds = preds; n->succs = succs; return ControlFlowGraph(n); } -// Extracts a basic block and updates the running lists blocks, preds, and succs. -// The return value is the index of the final basic block processed in the seq expression +// Extracts a basic block and updates the running lists bindings, preds, and succs. +// The return value is the index of the final binding processed in the seq expression // (useful for processing branches). -size_t ExtractCFGHelper(const SeqExpr& seq, const Array& args, size_t start_block_idx, - size_t start_binding_idx, std::vector current_preds, - std::vector* blocks, std::vector>* preds, +size_t ExtractCFGHelper(const SeqExpr& seq, const Array& args, size_t block_idx, + size_t binding_idx, std::vector current_preds, + std::vector* bindings, + std::vector>* preds, std::vector>* succs) { - size_t end_block_idx = 0; - size_t end_binding_idx = 0; - Expr ret; - Optional branch_var; - Optional branch_expr; - - // go from the start index and continue until we hit the end of the block or a split point - bool hit_branch = false; - // note: if start_block_idx is past seq->blocks.size(), then the loop will not actually run - // and we will not hit a branch, so we will produce a basic block comprised only of the - // seq expr end expression - for (size_t i = start_block_idx; i < seq->blocks.size(); i++) { - for (size_t j = start_binding_idx; j < seq->blocks[i]->bindings.size(); j++) { - Binding binding = seq->blocks[i]->bindings[j]; - if (auto* var_binding = binding.as()) { - if (var_binding->value.as()) { - end_block_idx = i; - end_binding_idx = j; - branch_var = var_binding->var; - branch_expr = Downcast(var_binding->value); - ret = branch_expr.value()->cond; - hit_branch = true; - break; - } - } else if (auto* match_binding = binding.as()) { - if (match_binding->value.as()) { - end_block_idx = i; - end_binding_idx = j; - branch_var = var_binding->var; - branch_expr = Downcast(var_binding->value); - ret = branch_expr.value()->cond; - hit_branch = true; - break; - } - } else { - CHECK(false); // will never happen - } - } - if (hit_branch) { - break; - } + // case 1: We're past the end -> this is the block body (base case) + if (block_idx == seq->blocks.size()) { + bindings->push_back(GraphBinding::Create(seq, args, block_idx, 0U, BindingNodeKind::kSeqBody)); + preds->push_back(current_preds); + // the final binding has no successors + succs->push_back({}); + return bindings->size() - 1; } - if (!hit_branch) { - end_block_idx = seq->blocks.size(); - end_binding_idx = 0U; // doesn't matter which we use - ret = seq->body; - } - BasicBlock block = BasicBlock::Create(seq, args, ret, start_block_idx, start_binding_idx, - end_block_idx, end_binding_idx); - blocks->push_back(block); - size_t block_idx = blocks->size() - 1U; - succs->push_back({}); - preds->push_back(current_preds); - for (size_t pred : current_preds) { - succs->at(pred).push_back(block_idx); - } - // no branches: then we're done - if (!hit_branch) { - return block_idx; + Binding binding = seq->blocks[block_idx]->bindings[binding_idx]; + Expr binding_value; + if (auto* var_binding = binding.as()) { + binding_value = var_binding->value; + } else if (auto* match_binding = binding.as()) { + binding_value = match_binding->value; + } else { + CHECK(false) << "Invalid binding (should never happen)"; } - // hit a branch: recurse down the branches and then set up the merge block - SeqExpr true_branch = Downcast(branch_expr.value()->true_branch); - SeqExpr false_branch = Downcast(branch_expr.value()->false_branch); - // the branches could contain their own branches, which is why we return the final block index - size_t end_true = ExtractCFGHelper(true_branch, {}, 0U, 0U, {block_idx}, blocks, preds, succs); - size_t end_false = ExtractCFGHelper(false_branch, {}, 0U, 0U, {block_idx}, blocks, preds, succs); - - // work out the start indices for the merge point - size_t next_start_block_idx = end_block_idx; - size_t next_start_binding_idx = end_binding_idx; - // figure out the next indices - if (end_binding_idx == seq->blocks[end_block_idx]->bindings.size() - 1) { - if (end_block_idx == seq->blocks.size() - 1) { - next_start_block_idx = seq->blocks.size(); - next_start_binding_idx = 0U; - } else { - next_start_block_idx = end_block_idx + 1; - next_start_binding_idx = 0U; - } + + // case 2: Ordinary binding + if (!binding_value.as()) { + bindings->push_back( + GraphBinding::Create(seq, args, block_idx, binding_idx, BindingNodeKind::kBinding)); + size_t idx = bindings->size() - 1; + preds->push_back(current_preds); + // successor: the next binding (there will always be at least one binding after this, + // even if it's the seq body) + succs->push_back({idx + 1}); } else { - next_start_binding_idx = end_binding_idx + 1; + // case 3: dealing with a branch + auto if_node = Downcast(binding_value); + // start with the cond node + bindings->push_back( + GraphBinding::Create(seq, args, block_idx, binding_idx, BindingNodeKind::kIfCond)); + size_t idx = bindings->size() - 1; + preds->push_back(current_preds); + // there will be another successor, which we will add after recursing down the branches + succs->push_back({idx + 1}); + size_t final_true_idx = ExtractCFGHelper(Downcast(if_node->true_branch), {}, 0U, 0U, + {idx}, bindings, preds, succs); + succs->at(idx).push_back(final_true_idx + 1); + size_t final_false_idx = ExtractCFGHelper(Downcast(if_node->false_branch), {}, 0U, 0U, + {idx}, bindings, preds, succs); + // now create the merge + bindings->push_back( + GraphBinding::Create(seq, {}, block_idx, binding_idx, BindingNodeKind::kIfMerge)); + size_t merge_idx = bindings->size() - 1; + preds->push_back({final_true_idx, final_false_idx}); + succs->push_back({merge_idx + 1}); + // update the successors of the final true and false indices as well + succs->at(final_true_idx).push_back(merge_idx); + succs->at(final_false_idx).push_back(merge_idx); + } + // move on to next binding + size_t next_block_idx = block_idx; + size_t next_binding_idx = binding_idx + 1; + if (next_binding_idx >= seq->blocks[block_idx]->bindings.size()) { + next_block_idx = block_idx + 1; + next_binding_idx = 0U; } - return ExtractCFGHelper(seq, {branch_var.value()}, next_start_block_idx, next_start_binding_idx, - {end_true, end_false}, blocks, preds, succs); + return ExtractCFGHelper(seq, {}, next_block_idx, next_binding_idx, {bindings->size() - 1}, + bindings, preds, succs); } ControlFlowGraph ExtractCFG(const Function& func) { - std::vector blocks; + std::vector bindings; std::vector> preds; std::vector> succs; - ExtractCFGHelper(Downcast(func->body), func->params, 0U, 0U, {}, &blocks, &preds, + ExtractCFGHelper(Downcast(func->body), func->params, 0U, 0U, {}, &bindings, &preds, &succs); Array> pred_arr; @@ -174,16 +149,16 @@ ControlFlowGraph ExtractCFG(const Function& func) { } succ_arr.push_back(succ_ints); } - return ControlFlowGraph::Create(Array(blocks), pred_arr, succ_arr); + return ControlFlowGraph::Create(Array(bindings), pred_arr, succ_arr); } std::pair, Array> DataflowAnalysis( const ControlFlowGraph& cfg, const ObjectRef& init, - std::function transfer_func, + std::function transfer_func, std::function merge_func, bool forward) { std::vector in_map; std::vector out_map; - for (size_t i = 0; i < cfg->blocks.size(); i++) { + for (size_t i = 0; i < cfg->bindings.size(); i++) { in_map.push_back(init); out_map.push_back(init); } @@ -192,7 +167,7 @@ std::pair, Array> DataflowAnalysis( // Since there are no loops in our AST, one traversal through the CFG suffices. // We will do BFS std::queue worklist; - worklist.push((forward) ? 0 : cfg->blocks.size() - 1); + worklist.push((forward) ? 0 : cfg->bindings.size() - 1); while (!worklist.empty()) { size_t idx = worklist.front(); worklist.pop(); @@ -210,7 +185,7 @@ std::pair, Array> DataflowAnalysis( : (prev.size() == 1) ? results->at(prev[0].IntValue()) : merge_func(results->at(prev[0].IntValue()), results->at(prev[1].IntValue())); - results->operator[](idx) = transfer_func(cfg->blocks[idx], inputs->at(idx)); + results->operator[](idx) = transfer_func(cfg->bindings[idx], inputs->at(idx)); for (Integer next_idx : next) { worklist.push(next_idx.IntValue()); @@ -220,16 +195,15 @@ std::pair, Array> DataflowAnalysis( return {Array(in_map), Array(out_map)}; } -TVM_REGISTER_GLOBAL("relax.analysis.BasicBlock") - .set_body_typed([](const SeqExpr& seq, const Array& args, const Expr& ret, - size_t start_block_idx, size_t start_binding_idx, size_t end_block_idx, - size_t end_binding_idx) { - return BasicBlock::Create(seq, args, ret, start_block_idx, start_binding_idx, end_block_idx, - end_binding_idx); +TVM_REGISTER_GLOBAL("relax.analysis.GraphBinding") + .set_body_typed([](const SeqExpr& seq, const Array& args, size_t block_idx, + size_t binding_idx, int kind) { + return GraphBinding::Create(seq, args, block_idx, binding_idx, + static_cast(kind)); }); TVM_REGISTER_GLOBAL("relax.analysis.ControlFlowGraph") - .set_body_typed([](const Array& blocks, const Array>& preds, + .set_body_typed([](const Array& blocks, const Array>& preds, const Array>& succs) { return ControlFlowGraph::Create(blocks, preds, succs); }); diff --git a/tests/python/relax/test_dataflow_analysis.py b/tests/python/relax/test_dataflow_analysis.py index 5013a94674b4..a6fb1e50dfd1 100644 --- a/tests/python/relax/test_dataflow_analysis.py +++ b/tests/python/relax/test_dataflow_analysis.py @@ -19,9 +19,9 @@ from tvm import relax from tvm.relax.analysis.dataflow_analysis import ( ControlFlowGraph, - BasicBlock, ExtractCFG, DataflowAnalysis, + BindingNodeKind, ) from tvm.script import ir as I, relax as R import tvm.testing @@ -44,45 +44,34 @@ def assert_pred_succ_lists(graph: ControlFlowGraph, expected_preds: List[List[in ) -# common pattern in normalization that we can check for: -# if condition: -# ... -# z = value1 -# else: -# ... -# z = value2 -# -# results in: -# -# VarBinding( -# z, -# If( -# condition, -# SeqExpr([..., BindingBlock([..., VarBinding(new_var1, value1)])], body=new_var1), -# SeqExpr([..., BindingBlock([..., VarBinding(new_var2, value2)])], body=new_var2) -# ) -# ) -# This function can be used for checking the SeqExprs inside the branches -def assert_ret_is_final_binding_in_seq(block: BasicBlock, check_op: Optional[str] = None): - seq_body = block.seq.body - final_binding = block.seq.blocks[-1].bindings[-1] - assert seq_body == final_binding.var - assert block.ret == seq_body - if check_op is not None: - assert isinstance(final_binding.value, relax.Call) - assert final_binding.value.op.name == check_op - - -# ensure that the exprs in each list match each other and that they do not match those in the other lists -def assert_distinct(*groups: List[relax.Expr]): +def assert_binding_fields( + graph: ControlFlowGraph, + idx: int, + block_idx: int, + binding_idx: int, + kind: BindingNodeKind = BindingNodeKind.kBinding, + args: Optional[List[relax.Var]] = None, +): + binding = graph.bindings[idx] + assert binding.block_idx == block_idx + assert binding.binding_idx == binding_idx + assert binding.kind == kind.value + if args is not None: + assert len(binding.args) == len(args) + for i in range(len(args)): + assert binding.args[i] == args[i] + + +# assert that the SeqExprs for each bindings match within groups and do not match other groups +def assert_distinct_seqs(cfg: ControlFlowGraph, *groups: List[int]): for i, group in enumerate(groups): if len(group) == 0: continue - for item in group[1:]: - assert item == group[0] + for idx in group[1:]: + assert cfg.bindings[idx].seq == cfg.bindings[group[0]].seq for other_group in groups[i + 1 :]: - for item in other_group: - assert group[0] != item + for idx in other_group: + assert cfg.bindings[group[0]].seq != cfg.bindings[idx].seq def test_trivial_CFG(): @@ -93,12 +82,9 @@ def main() -> R.Tensor((), "int32"): return R.const(1, dtype="int32") graph = ExtractCFG(TrivialFunc["main"]) - assert len(graph.blocks) == 1 + assert len(graph.bindings) == 1 assert_pred_succ_lists(graph, [[]]) - assert graph.blocks[0].ret == TrivialFunc["main"].body.body - assert graph.blocks[0].start_block_idx == 0 - assert graph.blocks[0].start_binding_idx == 0 - assert graph.blocks[0].end_block_idx == 0 + assert_binding_fields(graph, 0, 0, 0, kind=BindingNodeKind.kSeqBody) def test_sequence_of_bindings(): @@ -112,13 +98,12 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): return q graph = ExtractCFG(FuncWithBindings["main"]) - assert len(graph.blocks) == 1 - assert_pred_succ_lists(graph, [[]]) - assert graph.blocks[0].ret == FuncWithBindings["main"].body.body - assert graph.blocks[0].args == FuncWithBindings["main"].params - assert graph.blocks[0].start_block_idx == 0 - assert graph.blocks[0].start_binding_idx == 0 - assert graph.blocks[0].end_block_idx == 1 + assert len(graph.bindings) == 4 + assert_pred_succ_lists(graph, [[], [0], [1], [2]]) + assert_binding_fields(graph, 0, 0, 0, args=[FuncWithBindings["main"].params[0]]) + assert_binding_fields(graph, 1, 0, 1) + assert_binding_fields(graph, 2, 0, 2) + assert_binding_fields(graph, 3, 1, 0, kind=BindingNodeKind.kSeqBody) def test_dataflow_block(): @@ -140,14 +125,16 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): return u graph = ExtractCFG(FuncWithDataflow["main"]) - assert len(graph.blocks) == 1 - assert_pred_succ_lists(graph, [[]]) - assert graph.blocks[0].ret == FuncWithDataflow["main"].body.body - assert graph.blocks[0].args == FuncWithDataflow["main"].params - assert graph.blocks[0].start_block_idx == 0 - assert graph.blocks[0].start_binding_idx == 0 - # there are four binding blocks but they form one basic block - assert graph.blocks[0].end_block_idx == 4 + assert len(graph.bindings) == 8 + assert_pred_succ_lists(graph, [[], [0], [1], [2], [3], [4], [5], [6]]) + assert_binding_fields(graph, 0, 0, 0, args=FuncWithDataflow["main"].params) + assert_binding_fields(graph, 1, 0, 1) + assert_binding_fields(graph, 2, 1, 0) + assert_binding_fields(graph, 3, 1, 1) + assert_binding_fields(graph, 4, 2, 0) + assert_binding_fields(graph, 5, 3, 0) + assert_binding_fields(graph, 6, 3, 1) + assert_binding_fields(graph, 7, 4, 0, kind=BindingNodeKind.kSeqBody) def test_simple_branch(): @@ -165,44 +152,27 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): z = R.multiply(y, y) return z - # basic blocks: - # 1. the starting block (no bindings) whose return is the branch condition - # 2. the true branch body (return: R.multiply(y, y)) - # 3. the false branch body (return: R.multiply(y, y)) - # 4. the merge block (no bindings, argument is z) whose return is z graph = ExtractCFG(SimpleBranch["main"]) - assert len(graph.blocks) == 4 - assert_pred_succ_lists(graph, [[], [0], [0], [1, 2]]) - - assert graph.blocks[0].args == SimpleBranch["main"].params - assert graph.blocks[0].ret == SimpleBranch["main"].params[0] - assert graph.blocks[0].start_block_idx == 0 - assert graph.blocks[0].start_binding_idx == 0 - assert graph.blocks[0].end_block_idx == 0 - assert graph.blocks[0].end_binding_idx == 0 - - assert len(graph.blocks[1].args) == 0 - assert_ret_is_final_binding_in_seq(graph.blocks[1], "relax.multiply") - assert graph.blocks[1].start_block_idx == 0 - assert graph.blocks[1].start_binding_idx == 0 - assert graph.blocks[1].end_block_idx == 1 - - assert len(graph.blocks[2].args) == 0 - assert_ret_is_final_binding_in_seq(graph.blocks[2], "relax.multiply") - assert graph.blocks[2].start_block_idx == 0 - assert graph.blocks[2].start_binding_idx == 0 - assert graph.blocks[2].end_block_idx == 1 - - assert len(graph.blocks[3].args) == 1 - assert graph.blocks[3].args[0].name_hint == "z" - assert graph.blocks[3].ret == SimpleBranch["main"].body.body - # the if was the last binding in the block, so we're past the end - assert graph.blocks[3].start_block_idx == 1 - assert graph.blocks[3].end_block_idx == 1 - - assert_distinct( - [graph.blocks[0].seq, graph.blocks[3].seq], [graph.blocks[1]], [graph.blocks[2]] + + # cond binding + 3 bindings in true branch + true branch end + # + 3 bindings in false branch + false branch end + merge + seq body + assert len(graph.bindings) == 11 + assert_pred_succ_lists(graph, [[], [0], [1], [2], [3], [0], [5], [6], [7], [4, 8], [9]]) + + assert_binding_fields( + graph, 0, 0, 0, kind=BindingNodeKind.kIfCond, args=SimpleBranch["main"].params ) + assert_binding_fields(graph, 1, 0, 0) + assert_binding_fields(graph, 2, 0, 1) + assert_binding_fields(graph, 3, 0, 2) + assert_binding_fields(graph, 4, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 5, 0, 0) + assert_binding_fields(graph, 6, 0, 1) + assert_binding_fields(graph, 7, 0, 2) + assert_binding_fields(graph, 8, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 9, 0, 0, kind=BindingNodeKind.kIfMerge) + assert_binding_fields(graph, 10, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_distinct_seqs(graph, [0, 9], [1, 4], [5, 8]) def test_bindings_after_branch(): @@ -220,42 +190,19 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): return q graph = ExtractCFG(BranchAndBind["main"]) - assert len(graph.blocks) == 4 - assert_pred_succ_lists(graph, [[], [0], [0], [1, 2]]) - - # same as above example, except there are bindings preceding the if (included in block 0) - # and after the if (included in block 3) - - assert graph.blocks[0].args == BranchAndBind["main"].params - assert graph.blocks[0].ret == BranchAndBind["main"].params[0] - assert graph.blocks[0].start_block_idx == 0 - assert graph.blocks[0].start_binding_idx == 0 - assert graph.blocks[0].end_block_idx == 0 - assert graph.blocks[0].end_binding_idx == 2 - - assert len(graph.blocks[1].args) == 0 - assert_ret_is_final_binding_in_seq(graph.blocks[1], "relax.multiply") - assert graph.blocks[1].start_block_idx == 0 - assert graph.blocks[1].start_binding_idx == 0 - assert graph.blocks[1].end_block_idx == 1 - - assert len(graph.blocks[2].args) == 0 - assert_ret_is_final_binding_in_seq(graph.blocks[2], "relax.add") - assert graph.blocks[2].start_block_idx == 0 - assert graph.blocks[2].start_binding_idx == 0 - assert graph.blocks[2].end_block_idx == 1 - - assert len(graph.blocks[3].args) == 1 - assert graph.blocks[3].args[0].name_hint == "z" - assert graph.blocks[3].ret.name_hint == "q" - assert graph.blocks[3].start_block_idx == 0 - assert graph.blocks[3].start_binding_idx == 3 - assert graph.blocks[3].end_block_idx == 1 - assert graph.blocks[3].end_binding_idx == 0 - - assert_distinct( - [graph.blocks[0].seq, graph.blocks[3].seq], [graph.blocks[1]], [graph.blocks[2]] - ) + assert len(graph.bindings) == 10 + assert_pred_succ_lists(graph, [[], [0], [1], [2], [3], [2], [5], [4, 6], [7], [8]]) + assert_binding_fields(graph, 0, 0, 0, args=BranchAndBind["main"].params) + assert_binding_fields(graph, 1, 0, 1) + assert_binding_fields(graph, 2, 0, 2, kind=BindingNodeKind.kIfCond) + assert_binding_fields(graph, 3, 0, 0) + assert_binding_fields(graph, 4, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 5, 0, 0) + assert_binding_fields(graph, 6, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 7, 0, 2, kind=BindingNodeKind.kIfMerge) + assert_binding_fields(graph, 8, 0, 3) + assert_binding_fields(graph, 9, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_distinct_seqs(graph, [0, 2, 7, 9], [3, 4], [5, 6]) def test_branch_with_multiple_blocks(): @@ -287,38 +234,54 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): graph = ExtractCFG(LongBranches["main"]) # empty entry block, one block for each branch, and an empty exit block - assert len(graph.blocks) == 4 - assert_pred_succ_lists(graph, [[], [0], [0], [1, 2]]) - - assert graph.blocks[0].args == LongBranches["main"].params - assert graph.blocks[0].ret == LongBranches["main"].params[0] - assert graph.blocks[0].start_block_idx == 0 - assert graph.blocks[0].start_binding_idx == 0 - assert graph.blocks[0].end_block_idx == 0 - assert graph.blocks[0].end_binding_idx == 0 - - # there are 3 binding blocks included in each branch - assert len(graph.blocks[1].args) == 0 - assert_ret_is_final_binding_in_seq(graph.blocks[1], "relax.multiply") - assert graph.blocks[1].start_block_idx == 0 - assert graph.blocks[1].start_binding_idx == 0 - assert graph.blocks[1].end_block_idx == 3 - - assert len(graph.blocks[2].args) == 0 - assert_ret_is_final_binding_in_seq(graph.blocks[2], "relax.add") - assert graph.blocks[2].start_block_idx == 0 - assert graph.blocks[2].start_binding_idx == 0 - assert graph.blocks[2].end_block_idx == 3 - - assert len(graph.blocks[3].args) == 1 - assert graph.blocks[3].args[0].name_hint == "r" - assert graph.blocks[3].ret.name_hint == "r" - assert graph.blocks[3].start_block_idx == 1 - assert graph.blocks[3].end_block_idx == 1 - - assert_distinct( - [graph.blocks[0].seq, graph.blocks[3].seq], [graph.blocks[1]], [graph.blocks[2]] + assert len(graph.bindings) == 19 + assert_pred_succ_lists( + graph, + [ + [], + [0], + [1], + [2], + [3], + [4], + [5], + [6], + [7], + [0], + [9], + [10], + [11], + [12], + [13], + [14], + [15], + [8, 16], + [17], + ], + ) + + assert_binding_fields( + graph, 0, 0, 0, kind=BindingNodeKind.kIfCond, args=LongBranches["main"].params ) + assert_binding_fields(graph, 1, 0, 0) + assert_binding_fields(graph, 2, 0, 1) + assert_binding_fields(graph, 3, 1, 0) + assert_binding_fields(graph, 4, 1, 1) + assert_binding_fields(graph, 5, 1, 2) + assert_binding_fields(graph, 6, 2, 0) + assert_binding_fields(graph, 7, 2, 1) + assert_binding_fields(graph, 8, 3, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 9, 0, 0) + assert_binding_fields(graph, 10, 0, 1) + assert_binding_fields(graph, 11, 1, 0) + assert_binding_fields(graph, 12, 1, 1) + assert_binding_fields(graph, 13, 1, 2) + assert_binding_fields(graph, 14, 2, 0) + assert_binding_fields(graph, 15, 2, 1) + assert_binding_fields(graph, 16, 3, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 17, 0, 0, kind=BindingNodeKind.kIfMerge) + assert_binding_fields(graph, 18, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_distinct_seqs(graph, [0, 17, 18], [1, 8], [9, 16]) def test_nested_branches(): @@ -344,91 +307,68 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): return z graph = ExtractCFG(NestedBranches["main"]) - # basic blocks: entry block to func, entry block to true branch, true branch in true branch, - # false branch in true branch, merge block in true branch, - # entry to false branch, true branch in false branch, false branch in false branch, - # merge block in false branch, merge block in outer function - assert len(graph.blocks) == 10 + assert len(graph.bindings) == 22 assert_pred_succ_lists( graph, [ - [], # function entry - [0], # true branch entry - [1], # true branch's true branch - [1], # true branch's false branch - [2, 3], # true branch's exit - [0], # false branch entry - [5], # false branch's true branch - [5], # false branch's false branch - [6, 7], # false branch exit - [4, 8], # function exit + [], # first binding + [0], # branch cond + [1], # first binding in true branch + [2], # mested if condition + [3], # binding inside nested true branch + [4], # end of nested true branch + [3], # binding inside nested false branch + [6], # end of nested false branch + [5, 7], # merge for nested if + [8], # binding after nested if + [9], # end of outer true branch + [1], # first binding in false branch + [11], # nested if condition, + [12], # binding inside nested true branch + [13], # end of nested true branch + [12], # binding inside nested false branch + [15], # end of nested false branch + [14, 16], # merge after nested if + [17], # binding after nested if + [18], # end of outer false branch + [10, 19], # merge after outer if + [20], # end of body ], ) - assert graph.blocks[0].args == NestedBranches["main"].params - assert graph.blocks[0].ret.name_hint == "cond1" - assert graph.blocks[0].start_block_idx == 0 - assert graph.blocks[0].start_binding_idx == 0 - assert graph.blocks[0].end_block_idx == 0 - assert graph.blocks[0].end_binding_idx == 1 - - assert len(graph.blocks[1].args) == 0 - assert graph.blocks[1].ret.name_hint == "cond2" - assert graph.blocks[1].start_block_idx == 0 - assert graph.blocks[1].start_binding_idx == 0 - assert graph.blocks[1].end_block_idx == 0 - assert graph.blocks[1].end_binding_idx == 1 - - assert len(graph.blocks[2].args) == 0 - assert_ret_is_final_binding_in_seq(graph.blocks[2], "relax.add") - assert graph.blocks[2].start_block_idx == 0 - assert graph.blocks[2].start_binding_idx == 0 - assert graph.blocks[2].end_block_idx == 1 - - assert len(graph.blocks[3].args) == 0 - assert_ret_is_final_binding_in_seq(graph.blocks[3], "relax.multiply") - assert graph.blocks[3].start_block_idx == 0 - assert graph.blocks[3].start_binding_idx == 0 - assert graph.blocks[3].end_block_idx == 1 - - assert len(graph.blocks[4].args) == 1 - assert graph.blocks[4].args[0].name_hint == "y" - assert_ret_is_final_binding_in_seq(graph.blocks[4], "relax.add") - assert graph.blocks[4].start_block_idx == 0 - assert graph.blocks[4].start_binding_idx == 2 - assert graph.blocks[4].end_block_idx == 1 - - assert len(graph.blocks[5].args) == 0 - assert graph.blocks[5].ret.name_hint == "cond3" - assert graph.blocks[5].start_block_idx == 0 - assert graph.blocks[5].start_binding_idx == 0 - assert graph.blocks[5].end_block_idx == 0 - assert graph.blocks[5].end_binding_idx == 1 - - assert len(graph.blocks[6].args) == 0 - assert_ret_is_final_binding_in_seq(graph.blocks[6], "relax.multiply") - assert graph.blocks[6].start_block_idx == 0 - assert graph.blocks[6].start_binding_idx == 0 - assert graph.blocks[6].end_block_idx == 1 - - assert len(graph.blocks[7].args) == 0 - assert_ret_is_final_binding_in_seq(graph.blocks[7], "relax.add") - assert graph.blocks[7].start_block_idx == 0 - assert graph.blocks[7].start_binding_idx == 0 - assert graph.blocks[7].end_block_idx == 1 - - assert len(graph.blocks[8].args) == 1 - assert graph.blocks[8].args[0].name_hint == "y" - assert_ret_is_final_binding_in_seq(graph.blocks[8], "relax.multiply") - assert graph.blocks[8].start_block_idx == 0 - assert graph.blocks[8].start_binding_idx == 2 - assert graph.blocks[8].end_block_idx == 1 - - assert len(graph.blocks[9].args) == 1 - assert graph.blocks[9].args[0].name_hint == "z" - assert graph.blocks[9].ret.name_hint == "z" - assert graph.blocks[9].start_block_idx == 1 - assert graph.blocks[9].end_block_idx == 1 + assert_binding_fields(graph, 0, 0, 0, args=NestedBranches["main"].params) + assert_binding_fields(graph, 1, 0, 1, kind=BindingNodeKind.kIfCond) + assert_binding_fields(graph, 2, 0, 0) + assert_binding_fields(graph, 3, 0, 1, kind=BindingNodeKind.kIfCond) + assert_binding_fields(graph, 4, 0, 0) + assert_binding_fields(graph, 5, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 6, 0, 0) + assert_binding_fields(graph, 7, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 8, 0, 1, kind=BindingNodeKind.kIfMerge) + assert_binding_fields(graph, 9, 0, 2) + assert_binding_fields(graph, 10, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 11, 0, 0) + assert_binding_fields(graph, 12, 0, 1, kind=BindingNodeKind.kIfCond) + assert_binding_fields(graph, 13, 0, 0) + assert_binding_fields(graph, 14, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 15, 0, 0) + assert_binding_fields(graph, 16, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 17, 0, 1, kind=BindingNodeKind.kIfMerge) + assert_binding_fields(graph, 18, 0, 2) + assert_binding_fields(graph, 19, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 20, 0, 1, kind=BindingNodeKind.kIfMerge) + assert_binding_fields(graph, 21, 1, 0, kind=BindingNodeKind.kSeqBody) + + assert_distinct_seqs( + graph, + [0, 1, 20, 21], + [2, 3, 8, 9, 10], + [4, 5], + [6, 7], + [11, 12, 17, 18, 19], + [13, 14], + [15, 16], + ) def test_simple_analysis(): @@ -438,7 +378,7 @@ class TrivialFunc: def main() -> R.Tensor((), "int32"): return R.const(1, dtype="int32") - # only one basic block to consider here + # only one binding to consider here init = {"a": 1} def transfer_func(_, domain): @@ -503,30 +443,42 @@ def merge_func(domain1, domain2): new_domain["merge"] = 1 return new_domain - def check_expected_maps(in_map, out_map, forward=True): - # merge will happen in the last block only - i = 0 if forward else 3 - assert len(in_map[i]) == 1 and len(out_map[i]) == 1 - assert in_map[i]["a"] == 1 - assert out_map[i]["a"] == 2 - - for j in (1, 2): - assert len(in_map[j]) == 1 and len(out_map[j]) == 1 - assert in_map[j]["a"] == 2 - assert out_map[j]["a"] == 3 - - i = 3 if forward else 0 - assert len(in_map[i]) == 2 and len(out_map[i]) == 2 - assert in_map[i]["a"] == 3 - assert in_map[i]["merge"] == 1 - assert out_map[i]["a"] == 4 - assert out_map[i]["merge"] == 2 - cfg = ExtractCFG(SimpleBranch["main"]) in_map, out_map = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward=True) - check_expected_maps(in_map, out_map, forward=True) + # start and true branch + for i in range(5): + assert in_map[i]["a"] == i + 1 + assert out_map[i]["a"] == i + 2 + # false branch + for i in range(5, 9): + assert in_map[i]["a"] == i - 3 + assert out_map[i]["a"] == i - 2 + # index 9 is the merge + assert in_map[9]["a"] == 6 + assert in_map[9]["merge"] == 1 + assert out_map[9]["a"] == 7 + assert out_map[9]["merge"] == 2 + # index 10 is the last + assert in_map[10]["a"] == 7 + assert in_map[10]["merge"] == 2 + assert out_map[10]["a"] == 8 + assert out_map[10]["merge"] == 3 + in_map, out_map = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward=False) - check_expected_maps(out_map, in_map, forward=False) + # backward direction: start with index 10 + # end of seq through false branch + for i in range(6): + assert out_map[10 - i]["a"] == i + 1 + assert in_map[10 - i]["a"] == i + 2 + # true branch + for i in range(4): + assert out_map[4 - i]["a"] == i + 3 + assert in_map[4 - i]["a"] == i + 4 + # the if condition is the merge + assert out_map[0]["a"] == 7 + assert out_map[0]["merge"] == 1 + assert in_map[0]["a"] == 8 + assert in_map[0]["merge"] == 2 if __name__ == "__main__": From 9eaa2711328875fcd7999931d13c67ba0277b6b1 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 6 Sep 2023 16:29:59 -0400 Subject: [PATCH 05/18] Add GetBoundValue utility function --- src/relax/analysis/dataflow_analysis.cc | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/relax/analysis/dataflow_analysis.cc b/src/relax/analysis/dataflow_analysis.cc index 8a30beb117e2..a4c6fbbbf8de 100644 --- a/src/relax/analysis/dataflow_analysis.cc +++ b/src/relax/analysis/dataflow_analysis.cc @@ -72,14 +72,7 @@ size_t ExtractCFGHelper(const SeqExpr& seq, const Array& args, size_t block } Binding binding = seq->blocks[block_idx]->bindings[binding_idx]; - Expr binding_value; - if (auto* var_binding = binding.as()) { - binding_value = var_binding->value; - } else if (auto* match_binding = binding.as()) { - binding_value = match_binding->value; - } else { - CHECK(false) << "Invalid binding (should never happen)"; - } + Expr binding_value = GetBoundValue(binding); // case 2: Ordinary binding if (!binding_value.as()) { From 2f9abb9664774d7e6eec76cfaeb39b552a74a2ed Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 6 Sep 2023 16:30:24 -0400 Subject: [PATCH 06/18] Add GetBindingIndex helper function --- include/tvm/relax/dataflow_analysis.h | 21 +++++++---- .../tvm/relax/analysis/dataflow_analysis.py | 36 +++++++++++++++++++ src/relax/analysis/dataflow_analysis.cc | 35 ++++++++++++++++++ tests/python/relax/test_dataflow_analysis.py | 32 +++++++++++++++++ 4 files changed, 118 insertions(+), 6 deletions(-) diff --git a/include/tvm/relax/dataflow_analysis.h b/include/tvm/relax/dataflow_analysis.h index f8782e9a89ad..65038be65e05 100644 --- a/include/tvm/relax/dataflow_analysis.h +++ b/include/tvm/relax/dataflow_analysis.h @@ -46,12 +46,7 @@ namespace relax { * the SeqExprs in the true and false branches) * 4. The body expression in a SeqExpr (not actually bound) */ -enum BindingNodeKind : int { - kBinding = 0, - kIfCond = 1, - kIfMerge = 2, - kSeqBody = 3 -}; +enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; class GraphBindingNode : public Object { public: @@ -182,6 +177,20 @@ std::pair, Array> DataflowAnalysis( std::function transfer_func, std::function merge_func, bool forward = true); +/*! \brief A helper function. Given an index into a SeqExpr, give the index of the GraphBinding + * in the CFG. + * + * \param cfg The control flow graph. + * \param seq The target SeqExpr. + * \param block_idx The target block in the SeqExpr. + * Convention: Use one past the last block to indicate the SeqExpr body. + * \param binding_idx The target binding in the target block. + * \param match_cond If the RHS of the target binding is an IfExpr, then if match_cond is true, + * the returned index will be for the condition node; otherwise it will be for the merge node. + */ +size_t GetBindingIndex(const ControlFlowGraph& cfg, const SeqExpr& seq, size_t block_idx, + size_t binding_idx, bool match_cond); + } // namespace relax } // namespace tvm #endif \ No newline at end of file diff --git a/python/tvm/relax/analysis/dataflow_analysis.py b/python/tvm/relax/analysis/dataflow_analysis.py index e942e60f855e..ef55ca137cc0 100644 --- a/python/tvm/relax/analysis/dataflow_analysis.py +++ b/python/tvm/relax/analysis/dataflow_analysis.py @@ -132,6 +132,42 @@ def ExtractCFG(func: Function) -> ControlFlowGraph: return _ffi_api.ExtractCFG(func) # type: ignore +def GetBindingIndex( + cfg: ControlFlowGraph, seq: SeqExpr, block_idx: int, binding_idx: int, match_cond: bool = False +) -> int: + """ + Helper function. Given a control flow graph and a seq expression with a block index and binding index, + return the index of the corresponding GraphBinding in the CFG + + Parameters + ---------- + cfg: ControlFlowGraph + The control flow graph. + + seq: SeqExpr + The target SeqExpr. + + block_idx: int + The index of the target block in seq. + Convention: If the target is `seq.body`, block_idx should be one past the last block + (i.e., it should be equal to `len(seq.blocks)`). + + binding_idx: int + The index of the target binding in the target block. + + match_cond: bool + If true and the target binding in seq is an IfNode, then this function will return + the binding index corresponding to the If condition. + If false, then this function will return the binding index corresponding to the If merge. + + Returns + ------- + idx: int + The index of the corresponding GraphBindindg in `cfg.bindings`. + """ + return _ffi_api.GetBindingIndex(cfg, seq, block_idx, binding_idx, match_cond) # type: ignore + + def DataflowAnalysis( cfg: ControlFlowGraph, init: Any, diff --git a/src/relax/analysis/dataflow_analysis.cc b/src/relax/analysis/dataflow_analysis.cc index a4c6fbbbf8de..46ada151ddbb 100644 --- a/src/relax/analysis/dataflow_analysis.cc +++ b/src/relax/analysis/dataflow_analysis.cc @@ -188,6 +188,34 @@ std::pair, Array> DataflowAnalysis( return {Array(in_map), Array(out_map)}; } +size_t GetBindingIndex(const ControlFlowGraph& cfg, const SeqExpr& seq, size_t block_idx, + size_t binding_idx, bool match_cond) { + bool is_body = (block_idx == seq->blocks.size()); + bool is_if = + (!is_body && (GetBoundValue(seq->blocks[block_idx]->bindings[binding_idx]).as())); + + // This is an inefficient linear scan; it could be improved by keeping a map of + // SeqExprs to indices in the CFG data structure. + // That should be considered if this function poses performance issues (unlikely). + for (size_t i = 0; i < cfg->bindings.size(); i++) { + auto binding = cfg->bindings[i]; + if (binding->seq != seq) { + continue; + } + if (is_body && binding->kind == BindingNodeKind::kSeqBody) { + return i; + } + if (binding->block_idx == block_idx && binding->binding_idx == binding_idx) { + if (!is_if || (match_cond && binding->kind == BindingNodeKind::kIfCond) || + (!match_cond && binding->kind == BindingNodeKind::kIfMerge)) { + return i; + } + } + } + CHECK(false) << "Target binding does not appear in the given CFG"; + return cfg->bindings.size(); +} + TVM_REGISTER_GLOBAL("relax.analysis.GraphBinding") .set_body_typed([](const SeqExpr& seq, const Array& args, size_t block_idx, size_t binding_idx, int kind) { @@ -210,5 +238,12 @@ TVM_REGISTER_GLOBAL("relax.analysis.DataflowAnalysis") return Array({ret.first, ret.second}); }); +// need to turn the size_t's into ints in order to cross the C++<->Python boundary +TVM_REGISTER_GLOBAL("relax.analysis.GetBindingIndex") + .set_body_typed([](const ControlFlowGraph& cfg, const SeqExpr& seq, int block_idx, + int binding_idx, bool match_cond) -> int { + return GetBindingIndex(cfg, seq, block_idx, binding_idx, match_cond); + }); + } // namespace relax } // namespace tvm \ No newline at end of file diff --git a/tests/python/relax/test_dataflow_analysis.py b/tests/python/relax/test_dataflow_analysis.py index a6fb1e50dfd1..86c8c5dc23e2 100644 --- a/tests/python/relax/test_dataflow_analysis.py +++ b/tests/python/relax/test_dataflow_analysis.py @@ -22,6 +22,7 @@ ExtractCFG, DataflowAnalysis, BindingNodeKind, + GetBindingIndex, ) from tvm.script import ir as I, relax as R import tvm.testing @@ -481,5 +482,36 @@ def merge_func(domain1, domain2): assert in_map[0]["merge"] == 2 +def test_get_binding_index(): + @I.ir_module + class BranchAndBind: + @R.function + def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): + x = R.const(1, dtype="int32") + y = R.add(x, x) + if cond: + z = R.multiply(y, y) + else: + z = R.add(y, y) + q = R.add(z, z) + return q + + graph = ExtractCFG(BranchAndBind["main"]) + outer_seq = BranchAndBind["main"].body + true_seq = BranchAndBind["main"].body.blocks[0].bindings[2].value.true_branch + false_seq = BranchAndBind["main"].body.blocks[0].bindings[2].value.false_branch + + assert GetBindingIndex(graph, outer_seq, 0, 0) == 0 + assert GetBindingIndex(graph, outer_seq, 0, 1) == 1 + assert GetBindingIndex(graph, outer_seq, 0, 2, match_cond=True) == 2 + assert GetBindingIndex(graph, true_seq, 0, 0) == 3 + assert GetBindingIndex(graph, true_seq, 1, 0) == 4 + assert GetBindingIndex(graph, false_seq, 0, 0) == 5 + assert GetBindingIndex(graph, false_seq, 1, 0) == 6 + assert GetBindingIndex(graph, outer_seq, 0, 2) == 7 # the merge + assert GetBindingIndex(graph, outer_seq, 0, 3) == 8 + assert GetBindingIndex(graph, outer_seq, 1, 0) == 9 + + if __name__ == "__main__": tvm.testing.main() From fd7ccbfa5f4d13c07601ea6d3b01d4e95e7ae95c Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 6 Sep 2023 16:57:15 -0400 Subject: [PATCH 07/18] Fixing naming convention in Python dataflow analysis functions --- .../tvm/relax/analysis/dataflow_analysis.py | 6 +-- tests/python/relax/test_dataflow_analysis.py | 54 +++++++++---------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/python/tvm/relax/analysis/dataflow_analysis.py b/python/tvm/relax/analysis/dataflow_analysis.py index ef55ca137cc0..a7a2446a47bf 100644 --- a/python/tvm/relax/analysis/dataflow_analysis.py +++ b/python/tvm/relax/analysis/dataflow_analysis.py @@ -114,7 +114,7 @@ def __init__( ) # type: ignore -def ExtractCFG(func: Function) -> ControlFlowGraph: +def extract_cfg(func: Function) -> ControlFlowGraph: """ Given a Relax function, produces the corresponding control flow graph. The function is expected to have been normalized. @@ -132,7 +132,7 @@ def ExtractCFG(func: Function) -> ControlFlowGraph: return _ffi_api.ExtractCFG(func) # type: ignore -def GetBindingIndex( +def get_binding_index( cfg: ControlFlowGraph, seq: SeqExpr, block_idx: int, binding_idx: int, match_cond: bool = False ) -> int: """ @@ -168,7 +168,7 @@ def GetBindingIndex( return _ffi_api.GetBindingIndex(cfg, seq, block_idx, binding_idx, match_cond) # type: ignore -def DataflowAnalysis( +def dataflow_analysis( cfg: ControlFlowGraph, init: Any, transfer_func: Callable[[GraphBinding, Any], Any], diff --git a/tests/python/relax/test_dataflow_analysis.py b/tests/python/relax/test_dataflow_analysis.py index 86c8c5dc23e2..9d47cbce8ef5 100644 --- a/tests/python/relax/test_dataflow_analysis.py +++ b/tests/python/relax/test_dataflow_analysis.py @@ -19,10 +19,10 @@ from tvm import relax from tvm.relax.analysis.dataflow_analysis import ( ControlFlowGraph, - ExtractCFG, - DataflowAnalysis, + extract_cfg, + dataflow_analysis, BindingNodeKind, - GetBindingIndex, + get_binding_index, ) from tvm.script import ir as I, relax as R import tvm.testing @@ -82,7 +82,7 @@ class TrivialFunc: def main() -> R.Tensor((), "int32"): return R.const(1, dtype="int32") - graph = ExtractCFG(TrivialFunc["main"]) + graph = extract_cfg(TrivialFunc["main"]) assert len(graph.bindings) == 1 assert_pred_succ_lists(graph, [[]]) assert_binding_fields(graph, 0, 0, 0, kind=BindingNodeKind.kSeqBody) @@ -98,7 +98,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): q = R.multiply(z, x) return q - graph = ExtractCFG(FuncWithBindings["main"]) + graph = extract_cfg(FuncWithBindings["main"]) assert len(graph.bindings) == 4 assert_pred_succ_lists(graph, [[], [0], [1], [2]]) assert_binding_fields(graph, 0, 0, 0, args=[FuncWithBindings["main"].params[0]]) @@ -125,7 +125,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): R.output(u) return u - graph = ExtractCFG(FuncWithDataflow["main"]) + graph = extract_cfg(FuncWithDataflow["main"]) assert len(graph.bindings) == 8 assert_pred_succ_lists(graph, [[], [0], [1], [2], [3], [4], [5], [6]]) assert_binding_fields(graph, 0, 0, 0, args=FuncWithDataflow["main"].params) @@ -153,7 +153,7 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): z = R.multiply(y, y) return z - graph = ExtractCFG(SimpleBranch["main"]) + graph = extract_cfg(SimpleBranch["main"]) # cond binding + 3 bindings in true branch + true branch end # + 3 bindings in false branch + false branch end + merge + seq body @@ -190,7 +190,7 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): q = R.add(z, z) return q - graph = ExtractCFG(BranchAndBind["main"]) + graph = extract_cfg(BranchAndBind["main"]) assert len(graph.bindings) == 10 assert_pred_succ_lists(graph, [[], [0], [1], [2], [3], [2], [5], [4, 6], [7], [8]]) assert_binding_fields(graph, 0, 0, 0, args=BranchAndBind["main"].params) @@ -233,7 +233,7 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): r = R.add(q, q) return r - graph = ExtractCFG(LongBranches["main"]) + graph = extract_cfg(LongBranches["main"]) # empty entry block, one block for each branch, and an empty exit block assert len(graph.bindings) == 19 assert_pred_succ_lists( @@ -307,7 +307,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): z = R.multiply(y, y) return z - graph = ExtractCFG(NestedBranches["main"]) + graph = extract_cfg(NestedBranches["main"]) assert len(graph.bindings) == 22 assert_pred_succ_lists( graph, @@ -402,11 +402,11 @@ def check_expected_maps(in_map, out_map): assert out_map[0]["a"] == 1 assert out_map[0]["b"] == 2 - cfg = ExtractCFG(TrivialFunc["main"]) - in_map, out_map = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward=True) + cfg = extract_cfg(TrivialFunc["main"]) + in_map, out_map = dataflow_analysis(cfg, init, transfer_func, merge_func, forward=True) check_expected_maps(in_map, out_map) # backward will just flip in and out - in_map, out_map = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward=False) + in_map, out_map = dataflow_analysis(cfg, init, transfer_func, merge_func, forward=False) check_expected_maps(out_map, in_map) @@ -444,8 +444,8 @@ def merge_func(domain1, domain2): new_domain["merge"] = 1 return new_domain - cfg = ExtractCFG(SimpleBranch["main"]) - in_map, out_map = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward=True) + cfg = extract_cfg(SimpleBranch["main"]) + in_map, out_map = dataflow_analysis(cfg, init, transfer_func, merge_func, forward=True) # start and true branch for i in range(5): assert in_map[i]["a"] == i + 1 @@ -465,7 +465,7 @@ def merge_func(domain1, domain2): assert out_map[10]["a"] == 8 assert out_map[10]["merge"] == 3 - in_map, out_map = DataflowAnalysis(cfg, init, transfer_func, merge_func, forward=False) + in_map, out_map = dataflow_analysis(cfg, init, transfer_func, merge_func, forward=False) # backward direction: start with index 10 # end of seq through false branch for i in range(6): @@ -496,21 +496,21 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): q = R.add(z, z) return q - graph = ExtractCFG(BranchAndBind["main"]) + graph = extract_cfg(BranchAndBind["main"]) outer_seq = BranchAndBind["main"].body true_seq = BranchAndBind["main"].body.blocks[0].bindings[2].value.true_branch false_seq = BranchAndBind["main"].body.blocks[0].bindings[2].value.false_branch - assert GetBindingIndex(graph, outer_seq, 0, 0) == 0 - assert GetBindingIndex(graph, outer_seq, 0, 1) == 1 - assert GetBindingIndex(graph, outer_seq, 0, 2, match_cond=True) == 2 - assert GetBindingIndex(graph, true_seq, 0, 0) == 3 - assert GetBindingIndex(graph, true_seq, 1, 0) == 4 - assert GetBindingIndex(graph, false_seq, 0, 0) == 5 - assert GetBindingIndex(graph, false_seq, 1, 0) == 6 - assert GetBindingIndex(graph, outer_seq, 0, 2) == 7 # the merge - assert GetBindingIndex(graph, outer_seq, 0, 3) == 8 - assert GetBindingIndex(graph, outer_seq, 1, 0) == 9 + assert get_binding_index(graph, outer_seq, 0, 0) == 0 + assert get_binding_index(graph, outer_seq, 0, 1) == 1 + assert get_binding_index(graph, outer_seq, 0, 2, match_cond=True) == 2 + assert get_binding_index(graph, true_seq, 0, 0) == 3 + assert get_binding_index(graph, true_seq, 1, 0) == 4 + assert get_binding_index(graph, false_seq, 0, 0) == 5 + assert get_binding_index(graph, false_seq, 1, 0) == 6 + assert get_binding_index(graph, outer_seq, 0, 2) == 7 # the merge + assert get_binding_index(graph, outer_seq, 0, 3) == 8 + assert get_binding_index(graph, outer_seq, 1, 0) == 9 if __name__ == "__main__": From 55109849a5b0ec8013f11ac86a6f4596150b63c3 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 6 Sep 2023 17:34:39 -0400 Subject: [PATCH 08/18] Implement liveness analysis --- include/tvm/relax/analysis.h | 12 ++ python/tvm/relax/analysis/analysis.py | 25 +++- src/relax/analysis/liveness.cc | 122 ++++++++++++++++ .../relax/test_analysis_liveness_analysis.py | 131 ++++++++++++++++++ 4 files changed, 289 insertions(+), 1 deletion(-) create mode 100644 src/relax/analysis/liveness.cc create mode 100644 tests/python/relax/test_analysis_liveness_analysis.py diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 6e2209d51950..8abee12a5b98 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -457,6 +457,18 @@ struct VarUsageInfo { */ VarUsageInfo CollectVarUsage(const Expr& expr); +/*! + * \brief Perform a liveness analysis on the function, indicating which variables + * are live at which location in the function. + * + * \param fn The function to be analyzed. + * \return An array of arrays of live variables per binding in the function. + * The array is indexed based on the corresponding control flow graph, + * so use `ExtractCFG` and `GetBindingIndex` to match locations in `fn` + * to indices in the result. + */ +Array> LivenessAnalysis(const Function& fn); + /*! * \brief Remove unused statements inside DataflowBlocks. * diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index 38f5ea2fea0e..4ddcaf2fe3f6 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -21,7 +21,7 @@ configuring the passes and scripting them in Python. """ -from typing import Dict, List, Optional, Union, Callable +from typing import Dict, List, Optional, Set, Union, Callable from enum import IntEnum import tvm @@ -407,6 +407,29 @@ def udchain(dfb: DataflowBlock) -> Dict[Var, List[Var]]: return _ffi_api.udchain(dfb) # type: ignore +def liveness_analysis(func: Function) -> List[Set[Var]]: + """ + Perform a liveness analysis on the given function, returning a set of + the variables live in the given program location. + + Parameters + ---------- + func: Function + The function to be analyzed + + Returns + ------- + ret: List[Set[Var]] + The set of live variables for each binding in the function. + The indexing is determined by the control flow graph, so + use `extract_cfg` and `get_binding_index` to find the index + for a given program location in the list. + """ + live_lists = _ffi_api.LivenessAnalysis(func) + # convert the lists to sets + return [set(live_list) for live_list in live_lists] + + def name_to_binding(func: Function) -> Dict[str, List[Binding]]: """Return a map from variable name to its bindings.""" return _ffi_api.name_to_binding(func) # type: ignore diff --git a/src/relax/analysis/liveness.cc b/src/relax/analysis/liveness.cc new file mode 100644 index 000000000000..1c949cbf28d3 --- /dev/null +++ b/src/relax/analysis/liveness.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/analysis/liveness.cc + * \brief Implementation of liveness analysis + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +// just sets of vars. the bool value is unnecessary +using Domain = Map; + +Domain transfer_func(const GraphBinding& binding, const ObjectRef& input) { + Domain in_domain = Downcast(input); + Domain new_domain(in_domain); + + // 1. If a var that appears in the RHS of the binding, add it (it's live) + // 2. Remove the bound var (it is not live prior to being bound) + Array vars_used; + Optional var_bound; + if (binding->kind == BindingNodeKind::kSeqBody) { + vars_used = AllVars(binding->seq->body); + } else if (binding->kind == BindingNodeKind::kIfCond) { + Binding b = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]; + Expr cond = Downcast(GetBoundValue(b))->cond; + vars_used = AllVars(cond); + } else if (binding->kind == BindingNodeKind::kIfMerge) { + // no vars are used in the merge + vars_used = {}; + // define the merge var + var_bound = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]->var; + } else { + // the ordinary binding case + Binding b = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]; + Expr bound_value = GetBoundValue(b); + // special case: if the RHS is a function literal, we only care about the free vars + // (those captured by the closure) + if (bound_value.as()) { + vars_used = FreeVars(bound_value); + } else { + vars_used = AllVars(bound_value); + } + var_bound = b->var; + } + + for (auto var : vars_used) { + if (!new_domain.count(var)) { + new_domain.Set(var, Bool(true)); + } + } + + // the var bound is killed + if (var_bound.defined()) { + new_domain.erase(var_bound.value()); + } + + // technically, we could kill the args too, + // but they are not actually *bound* at the first binding + + return new_domain; +} + +// simply combine sets of live vars to merge +Domain merge_func(const ObjectRef& domain1, const ObjectRef& domain2) { + Domain merged; + for (auto kv : Downcast(domain1)) { + merged.Set(kv.first, kv.second); + } + for (auto kv : Downcast(domain2)) { + merged.Set(kv.first, kv.second); + } + return merged; +} + +Array> LivenessAnalysis(const Function& func) { + // initial domain is empty + Domain init_domain; + ControlFlowGraph cfg = ExtractCFG(func); + std::pair results = + DataflowAnalysis(cfg, init_domain, transfer_func, merge_func, false); + + // we will return the input map but convert the maps into arrays for simplicity + Array in_map = Downcast>(results.first); + + Array> ret; + for (const Domain& d : in_map) { + Array arr; + for (auto kv : d) { + arr.push_back(kv.first); + } + ret.push_back(arr); + } + return ret; +} + +TVM_REGISTER_GLOBAL("relax.analysis.LivenessAnalysis").set_body_typed(LivenessAnalysis); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis_liveness_analysis.py b/tests/python/relax/test_analysis_liveness_analysis.py new file mode 100644 index 000000000000..9483073ef08a --- /dev/null +++ b/tests/python/relax/test_analysis_liveness_analysis.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Set +import tvm +import tvm.testing +from tvm.script import ir as I, relax as R +from tvm.relax import Var +from tvm.relax.analysis import liveness_analysis + + +def assert_live_set(live_set: Set[Var], var_names: Set[str]) -> None: + assert len(live_set) == len(var_names) + for var in live_set: + assert var.name_hint in var_names + + +def test_simple_liveness(): + @I.ir_module + class SimpleFunc: + @R.function + def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + y = R.add(x, x) # live: x + z = R.add(y, y) # live: y + return z # live: z + + live_sets = liveness_analysis(SimpleFunc["main"]) + assert_live_set(live_sets[0], {"x"}) + assert_live_set(live_sets[1], {"y"}) + assert_live_set(live_sets[2], {"z"}) + + +def test_liveness_with_branches(): + @I.ir_module + class BranchingFunc: + @R.function + def main( + x: R.Tensor((), dtype="int32"), + y: R.Tensor((), dtype="int32"), + cond: R.Tensor((), dtype="bool"), + ) -> R.Tensor((), dtype="int32"): + z = R.add(x, x) # live: x, y, cond + q = R.add(z, z) # live: y, z, cond + if cond: # live: q, y, cond + r = R.subtract(q, y) # live: q, y + s = R.multiply(r, r) # live: r + # end of seq: the R.multiply will actually be bound to a fresh var + # and s will be used as the binding for the entire If node + else: + r = R.multiply(q, q) # live: q, y + s = R.subtract(r, y) # live: r, y + # end of seq: the R.subtract will actually be bound to a fresh var + # and s will be used as the binding for the entire If node + # merge point: nothing is live (s is the variable bound at the merge) + t = R.add(s, s) # live: s + u = R.multiply(t, s) # live: t, s + return u # live: u + + live_sets = liveness_analysis(BranchingFunc["main"]) + assert_live_set(live_sets[0], {"x", "y", "cond"}) + assert_live_set(live_sets[1], {"y", "z", "cond"}) + assert_live_set(live_sets[2], {"q", "y", "cond"}) + assert_live_set(live_sets[3], {"q", "y"}) + assert_live_set(live_sets[4], {"r"}) + # the name is created by the parser and will be a placeholder so this is the best we can do + assert len(live_sets[5]) == 1 and ( + BranchingFunc["main"].body.blocks[0].bindings[2].value.true_branch.body in live_sets[5] + ) + assert_live_set(live_sets[6], {"q", "y"}) + assert_live_set(live_sets[7], {"r", "y"}) + assert len(live_sets[8]) == 1 and ( + BranchingFunc["main"].body.blocks[0].bindings[2].value.false_branch.body in live_sets[8] + ) + assert_live_set(live_sets[9], {}) + assert_live_set(live_sets[10], {"s"}) + assert_live_set(live_sets[11], {"t", "s"}) + assert_live_set(live_sets[12], {"u"}) + + +def test_liveness_inner_func(): + @I.ir_module + class InnerFunc: + @R.function + def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + y = R.add(x, x) # live: x + z = R.add(y, y) # live: x, y + + # the inner func captures x and y and so counts as a use of both + # live: x, y, z + @R.function + def inner(q: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + # (note: we would need to do liveness analysis of the inner func + # separately to get liveness info for these locations) + r = R.add(x, q) # live: x, y, q (and z from outside) + s = R.multiply(y, r) # live: y, r (and z from outside) + return s # live: s (and z from outside) + + w = inner(z) # live: inner, z + return w # live: w + + live_sets = liveness_analysis(InnerFunc["main"]) + assert_live_set(live_sets[0], {"x"}) + assert_live_set(live_sets[1], {"x", "y"}) + assert_live_set(live_sets[2], {"x", "y", "z"}) + assert_live_set(live_sets[3], {"inner", "z"}) + assert_live_set(live_sets[4], {"w"}) + + # let's also analyze the inner func (note: we don't have a way to indicate + # that z is live from outside the func) + inner_live = liveness_analysis(InnerFunc["main"].body.blocks[0].bindings[2].value) + assert_live_set(inner_live[0], {"x", "y", "q"}) + assert_live_set(inner_live[1], {"y", "r"}) + assert_live_set(inner_live[2], {"s"}) + + +if __name__ == "__main__": + tvm.testing.main() From adb2d7d48389c634de0863799474dd657b6d5dae Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 6 Sep 2023 21:09:31 -0400 Subject: [PATCH 09/18] Trailing newline --- include/tvm/relax/dataflow_analysis.h | 2 +- src/relax/analysis/dataflow_analysis.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/relax/dataflow_analysis.h b/include/tvm/relax/dataflow_analysis.h index 65038be65e05..a80a660ca8b7 100644 --- a/include/tvm/relax/dataflow_analysis.h +++ b/include/tvm/relax/dataflow_analysis.h @@ -193,4 +193,4 @@ size_t GetBindingIndex(const ControlFlowGraph& cfg, const SeqExpr& seq, size_t b } // namespace relax } // namespace tvm -#endif \ No newline at end of file +#endif diff --git a/src/relax/analysis/dataflow_analysis.cc b/src/relax/analysis/dataflow_analysis.cc index 46ada151ddbb..df3d8d5b013e 100644 --- a/src/relax/analysis/dataflow_analysis.cc +++ b/src/relax/analysis/dataflow_analysis.cc @@ -246,4 +246,4 @@ TVM_REGISTER_GLOBAL("relax.analysis.GetBindingIndex") }); } // namespace relax -} // namespace tvm \ No newline at end of file +} // namespace tvm From d1438444e91241c58b0fa51cea4ea167adc04193 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 7 Sep 2023 13:25:58 -0400 Subject: [PATCH 10/18] Python style fixes --- .../tvm/relax/analysis/dataflow_analysis.py | 18 +++--- tests/python/relax/test_dataflow_analysis.py | 64 +++++++++---------- 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/python/tvm/relax/analysis/dataflow_analysis.py b/python/tvm/relax/analysis/dataflow_analysis.py index a7a2446a47bf..9c233f27edc0 100644 --- a/python/tvm/relax/analysis/dataflow_analysis.py +++ b/python/tvm/relax/analysis/dataflow_analysis.py @@ -21,15 +21,15 @@ from typing import Any, Callable, List, Tuple import tvm from tvm.ir.base import Node -from tvm.relax.expr import Expr, SeqExpr, Function, Var +from tvm.relax.expr import SeqExpr, Function, Var from . import _ffi_api class BindingNodeKind(Enum): - kBinding = 0 - kIfCond = 1 - kIfMerge = 2 - kSeqBody = 3 + Binding = 0 + IfCond = 1 + IfMerge = 2 + SeqBody = 3 @tvm._ffi.register_object("relax.analysis.GraphBinding") @@ -74,7 +74,7 @@ def __init__( If conditions, If merges (the var bound to the result of the If node), and the body of the SeqExpr. """ - return self.__init_handle_by_constructor__( + self.__init_handle_by_constructor__( _ffi_api.GraphBinding, seq, args, @@ -109,7 +109,7 @@ def __init__( if len(bindings) != len(preds) or len(bindings) != len(succs): raise ValueError("The lengths of blocks, preds, and succs must all match.") - return self.__init_handle_by_constructor__( + self.__init_handle_by_constructor__( _ffi_api.ControlFlowGraph, bindings, preds, succs ) # type: ignore @@ -136,8 +136,8 @@ def get_binding_index( cfg: ControlFlowGraph, seq: SeqExpr, block_idx: int, binding_idx: int, match_cond: bool = False ) -> int: """ - Helper function. Given a control flow graph and a seq expression with a block index and binding index, - return the index of the corresponding GraphBinding in the CFG + Helper function. Given a control flow graph and a seq expression with a block index + and binding index, return the index of the corresponding GraphBinding in the CFG Parameters ---------- diff --git a/tests/python/relax/test_dataflow_analysis.py b/tests/python/relax/test_dataflow_analysis.py index 9d47cbce8ef5..d61dfcc509e7 100644 --- a/tests/python/relax/test_dataflow_analysis.py +++ b/tests/python/relax/test_dataflow_analysis.py @@ -50,7 +50,7 @@ def assert_binding_fields( idx: int, block_idx: int, binding_idx: int, - kind: BindingNodeKind = BindingNodeKind.kBinding, + kind: BindingNodeKind = BindingNodeKind.Binding, args: Optional[List[relax.Var]] = None, ): binding = graph.bindings[idx] @@ -85,7 +85,7 @@ def main() -> R.Tensor((), "int32"): graph = extract_cfg(TrivialFunc["main"]) assert len(graph.bindings) == 1 assert_pred_succ_lists(graph, [[]]) - assert_binding_fields(graph, 0, 0, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 0, 0, 0, kind=BindingNodeKind.SeqBody) def test_sequence_of_bindings(): @@ -104,7 +104,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): assert_binding_fields(graph, 0, 0, 0, args=[FuncWithBindings["main"].params[0]]) assert_binding_fields(graph, 1, 0, 1) assert_binding_fields(graph, 2, 0, 2) - assert_binding_fields(graph, 3, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 3, 1, 0, kind=BindingNodeKind.SeqBody) def test_dataflow_block(): @@ -135,7 +135,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): assert_binding_fields(graph, 4, 2, 0) assert_binding_fields(graph, 5, 3, 0) assert_binding_fields(graph, 6, 3, 1) - assert_binding_fields(graph, 7, 4, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 7, 4, 0, kind=BindingNodeKind.SeqBody) def test_simple_branch(): @@ -161,18 +161,18 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): assert_pred_succ_lists(graph, [[], [0], [1], [2], [3], [0], [5], [6], [7], [4, 8], [9]]) assert_binding_fields( - graph, 0, 0, 0, kind=BindingNodeKind.kIfCond, args=SimpleBranch["main"].params + graph, 0, 0, 0, kind=BindingNodeKind.IfCond, args=SimpleBranch["main"].params ) assert_binding_fields(graph, 1, 0, 0) assert_binding_fields(graph, 2, 0, 1) assert_binding_fields(graph, 3, 0, 2) - assert_binding_fields(graph, 4, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 4, 1, 0, kind=BindingNodeKind.SeqBody) assert_binding_fields(graph, 5, 0, 0) assert_binding_fields(graph, 6, 0, 1) assert_binding_fields(graph, 7, 0, 2) - assert_binding_fields(graph, 8, 1, 0, kind=BindingNodeKind.kSeqBody) - assert_binding_fields(graph, 9, 0, 0, kind=BindingNodeKind.kIfMerge) - assert_binding_fields(graph, 10, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 8, 1, 0, kind=BindingNodeKind.SeqBody) + assert_binding_fields(graph, 9, 0, 0, kind=BindingNodeKind.IfMerge) + assert_binding_fields(graph, 10, 1, 0, kind=BindingNodeKind.SeqBody) assert_distinct_seqs(graph, [0, 9], [1, 4], [5, 8]) @@ -195,14 +195,14 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): assert_pred_succ_lists(graph, [[], [0], [1], [2], [3], [2], [5], [4, 6], [7], [8]]) assert_binding_fields(graph, 0, 0, 0, args=BranchAndBind["main"].params) assert_binding_fields(graph, 1, 0, 1) - assert_binding_fields(graph, 2, 0, 2, kind=BindingNodeKind.kIfCond) + assert_binding_fields(graph, 2, 0, 2, kind=BindingNodeKind.IfCond) assert_binding_fields(graph, 3, 0, 0) - assert_binding_fields(graph, 4, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 4, 1, 0, kind=BindingNodeKind.SeqBody) assert_binding_fields(graph, 5, 0, 0) - assert_binding_fields(graph, 6, 1, 0, kind=BindingNodeKind.kSeqBody) - assert_binding_fields(graph, 7, 0, 2, kind=BindingNodeKind.kIfMerge) + assert_binding_fields(graph, 6, 1, 0, kind=BindingNodeKind.SeqBody) + assert_binding_fields(graph, 7, 0, 2, kind=BindingNodeKind.IfMerge) assert_binding_fields(graph, 8, 0, 3) - assert_binding_fields(graph, 9, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 9, 1, 0, kind=BindingNodeKind.SeqBody) assert_distinct_seqs(graph, [0, 2, 7, 9], [3, 4], [5, 6]) @@ -262,7 +262,7 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): ) assert_binding_fields( - graph, 0, 0, 0, kind=BindingNodeKind.kIfCond, args=LongBranches["main"].params + graph, 0, 0, 0, kind=BindingNodeKind.IfCond, args=LongBranches["main"].params ) assert_binding_fields(graph, 1, 0, 0) assert_binding_fields(graph, 2, 0, 1) @@ -271,7 +271,7 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): assert_binding_fields(graph, 5, 1, 2) assert_binding_fields(graph, 6, 2, 0) assert_binding_fields(graph, 7, 2, 1) - assert_binding_fields(graph, 8, 3, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 8, 3, 0, kind=BindingNodeKind.SeqBody) assert_binding_fields(graph, 9, 0, 0) assert_binding_fields(graph, 10, 0, 1) assert_binding_fields(graph, 11, 1, 0) @@ -279,9 +279,9 @@ def main(cond: R.Tensor((), "bool")) -> R.Tensor((), "int32"): assert_binding_fields(graph, 13, 1, 2) assert_binding_fields(graph, 14, 2, 0) assert_binding_fields(graph, 15, 2, 1) - assert_binding_fields(graph, 16, 3, 0, kind=BindingNodeKind.kSeqBody) - assert_binding_fields(graph, 17, 0, 0, kind=BindingNodeKind.kIfMerge) - assert_binding_fields(graph, 18, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 16, 3, 0, kind=BindingNodeKind.SeqBody) + assert_binding_fields(graph, 17, 0, 0, kind=BindingNodeKind.IfMerge) + assert_binding_fields(graph, 18, 1, 0, kind=BindingNodeKind.SeqBody) assert_distinct_seqs(graph, [0, 17, 18], [1, 8], [9, 16]) @@ -338,27 +338,27 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): ) assert_binding_fields(graph, 0, 0, 0, args=NestedBranches["main"].params) - assert_binding_fields(graph, 1, 0, 1, kind=BindingNodeKind.kIfCond) + assert_binding_fields(graph, 1, 0, 1, kind=BindingNodeKind.IfCond) assert_binding_fields(graph, 2, 0, 0) - assert_binding_fields(graph, 3, 0, 1, kind=BindingNodeKind.kIfCond) + assert_binding_fields(graph, 3, 0, 1, kind=BindingNodeKind.IfCond) assert_binding_fields(graph, 4, 0, 0) - assert_binding_fields(graph, 5, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 5, 1, 0, kind=BindingNodeKind.SeqBody) assert_binding_fields(graph, 6, 0, 0) - assert_binding_fields(graph, 7, 1, 0, kind=BindingNodeKind.kSeqBody) - assert_binding_fields(graph, 8, 0, 1, kind=BindingNodeKind.kIfMerge) + assert_binding_fields(graph, 7, 1, 0, kind=BindingNodeKind.SeqBody) + assert_binding_fields(graph, 8, 0, 1, kind=BindingNodeKind.IfMerge) assert_binding_fields(graph, 9, 0, 2) - assert_binding_fields(graph, 10, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 10, 1, 0, kind=BindingNodeKind.SeqBody) assert_binding_fields(graph, 11, 0, 0) - assert_binding_fields(graph, 12, 0, 1, kind=BindingNodeKind.kIfCond) + assert_binding_fields(graph, 12, 0, 1, kind=BindingNodeKind.IfCond) assert_binding_fields(graph, 13, 0, 0) - assert_binding_fields(graph, 14, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 14, 1, 0, kind=BindingNodeKind.SeqBody) assert_binding_fields(graph, 15, 0, 0) - assert_binding_fields(graph, 16, 1, 0, kind=BindingNodeKind.kSeqBody) - assert_binding_fields(graph, 17, 0, 1, kind=BindingNodeKind.kIfMerge) + assert_binding_fields(graph, 16, 1, 0, kind=BindingNodeKind.SeqBody) + assert_binding_fields(graph, 17, 0, 1, kind=BindingNodeKind.IfMerge) assert_binding_fields(graph, 18, 0, 2) - assert_binding_fields(graph, 19, 1, 0, kind=BindingNodeKind.kSeqBody) - assert_binding_fields(graph, 20, 0, 1, kind=BindingNodeKind.kIfMerge) - assert_binding_fields(graph, 21, 1, 0, kind=BindingNodeKind.kSeqBody) + assert_binding_fields(graph, 19, 1, 0, kind=BindingNodeKind.SeqBody) + assert_binding_fields(graph, 20, 0, 1, kind=BindingNodeKind.IfMerge) + assert_binding_fields(graph, 21, 1, 0, kind=BindingNodeKind.SeqBody) assert_distinct_seqs( graph, From 1cd9d502215e5182c5df3b3fbbe5becf38944f5b Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 7 Sep 2023 17:10:57 -0400 Subject: [PATCH 11/18] Header file style fixes --- include/tvm/relax/dataflow_analysis.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/tvm/relax/dataflow_analysis.h b/include/tvm/relax/dataflow_analysis.h index a80a660ca8b7..aefbf920d1c9 100644 --- a/include/tvm/relax/dataflow_analysis.h +++ b/include/tvm/relax/dataflow_analysis.h @@ -32,6 +32,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -193,4 +195,4 @@ size_t GetBindingIndex(const ControlFlowGraph& cfg, const SeqExpr& seq, size_t b } // namespace relax } // namespace tvm -#endif +#endif // TVM_RELAX_DATAFLOW_ANALYSIS_H_ From bf6f0162dbb3c30ed6e583e90a1519ad5609fec3 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 28 Nov 2023 21:22:36 -0500 Subject: [PATCH 12/18] Remove redundant check --- src/relax/analysis/liveness.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/relax/analysis/liveness.cc b/src/relax/analysis/liveness.cc index 1c949cbf28d3..2bdc476cbc68 100644 --- a/src/relax/analysis/liveness.cc +++ b/src/relax/analysis/liveness.cc @@ -67,9 +67,7 @@ Domain transfer_func(const GraphBinding& binding, const ObjectRef& input) { } for (auto var : vars_used) { - if (!new_domain.count(var)) { - new_domain.Set(var, Bool(true)); - } + new_domain.Set(var, Bool(true)); } // the var bound is killed From 171fac615a28f5f3e7699650184511380653887a Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 28 Nov 2023 21:40:37 -0500 Subject: [PATCH 13/18] Add liveness analysis to __init__ for analysis --- python/tvm/relax/analysis/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py index d8454a02cc84..d2c4c889b5d7 100644 --- a/python/tvm/relax/analysis/__init__.py +++ b/python/tvm/relax/analysis/__init__.py @@ -32,6 +32,7 @@ get_static_type, get_var2val, has_reshape_pattern, + liveness_analysis, name_to_binding, post_order_visit, remove_all_unused, From 4289b5acbff74d6c58ac4c97361a3baa9ea0d5ee Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 28 Nov 2023 22:01:52 -0500 Subject: [PATCH 14/18] No need to distinguish between FreeVars and AllVars --- src/relax/analysis/liveness.cc | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/relax/analysis/liveness.cc b/src/relax/analysis/liveness.cc index 2bdc476cbc68..bad7aaf17eaa 100644 --- a/src/relax/analysis/liveness.cc +++ b/src/relax/analysis/liveness.cc @@ -56,13 +56,11 @@ Domain transfer_func(const GraphBinding& binding, const ObjectRef& input) { // the ordinary binding case Binding b = binding->seq->blocks[binding->block_idx]->bindings[binding->binding_idx]; Expr bound_value = GetBoundValue(b); - // special case: if the RHS is a function literal, we only care about the free vars - // (those captured by the closure) - if (bound_value.as()) { - vars_used = FreeVars(bound_value); - } else { - vars_used = AllVars(bound_value); - } + // For a function literal, we only care about the free vars + // (those captured by the closure). + // In all other cases, we want any var, but since the RHS would not contain + // any bindings of its own, that turns out to be the same as the free vars. + vars_used = FreeVars(bound_value); var_bound = b->var; } From 44da8010b988d4887b7938cd82262b522078af0c Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 28 Nov 2023 22:03:27 -0500 Subject: [PATCH 15/18] Use enum class instead of an int for enums --- include/tvm/relax/dataflow_analysis.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relax/dataflow_analysis.h b/include/tvm/relax/dataflow_analysis.h index aefbf920d1c9..68e0b8a36d94 100644 --- a/include/tvm/relax/dataflow_analysis.h +++ b/include/tvm/relax/dataflow_analysis.h @@ -48,7 +48,7 @@ namespace relax { * the SeqExprs in the true and false branches) * 4. The body expression in a SeqExpr (not actually bound) */ -enum BindingNodeKind : int { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; +enum class BindingNodeKind { kBinding = 0, kIfCond = 1, kIfMerge = 2, kSeqBody = 3 }; class GraphBindingNode : public Object { public: From 8516d23ce1d9d5ab0796959bedfba68524314bfe Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 28 Nov 2023 22:09:48 -0500 Subject: [PATCH 16/18] Map over the results of the liveness analysis when doing data structure conversion for safety --- src/relax/analysis/liveness.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/relax/analysis/liveness.cc b/src/relax/analysis/liveness.cc index bad7aaf17eaa..e3c4f7d0f6f1 100644 --- a/src/relax/analysis/liveness.cc +++ b/src/relax/analysis/liveness.cc @@ -99,7 +99,12 @@ Array> LivenessAnalysis(const Function& func) { DataflowAnalysis(cfg, init_domain, transfer_func, merge_func, false); // we will return the input map but convert the maps into arrays for simplicity - Array in_map = Downcast>(results.first); + + // The map is done for safety, since directly doing Downcast>(results.first) + // would *not* check the contents of results.first. + Array res_objs = Downcast>(results.first); + Array in_map = + res_objs.Map([](const ObjectRef& obj) { return Downcast(obj); }); Array> ret; for (const Domain& d : in_map) { From 90e09b426f54fe9e76d47aa3ca187c5c808c3a60 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 28 Nov 2023 22:24:03 -0500 Subject: [PATCH 17/18] Use constructors for CFG structures instead of Create functions --- include/tvm/relax/dataflow_analysis.h | 9 ++++--- src/relax/analysis/dataflow_analysis.cc | 32 +++++++++++-------------- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/include/tvm/relax/dataflow_analysis.h b/include/tvm/relax/dataflow_analysis.h index 68e0b8a36d94..af5823a82b84 100644 --- a/include/tvm/relax/dataflow_analysis.h +++ b/include/tvm/relax/dataflow_analysis.h @@ -99,8 +99,8 @@ class GraphBinding : public ObjectRef { * \param kind: The kind of binding this is. (Used especially to distinguish If node conditions * from the merge after the If) */ - TVM_DLL static GraphBinding Create(const SeqExpr& seq, const Array& args, size_t block_idx, - size_t binding_idx, BindingNodeKind kind); + TVM_DLL GraphBinding(const SeqExpr& seq, const Array& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(GraphBinding, ObjectRef, GraphBindingNode); }; @@ -136,9 +136,8 @@ class ControlFlowGraph : public ObjectRef { * \param preds: List of lists of predecessors to each binding. * \param succs: List of lists of successors to each binding. */ - TVM_DLL static ControlFlowGraph Create(const Array& bindings, - const Array>& preds, - const Array>& succs); + TVM_DLL ControlFlowGraph(const Array& bindings, const Array>& preds, + const Array>& succs); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ControlFlowGraph, ObjectRef, ControlFlowGraphNode); }; diff --git a/src/relax/analysis/dataflow_analysis.cc b/src/relax/analysis/dataflow_analysis.cc index df3d8d5b013e..151f4d18c749 100644 --- a/src/relax/analysis/dataflow_analysis.cc +++ b/src/relax/analysis/dataflow_analysis.cc @@ -31,27 +31,27 @@ namespace relax { TVM_REGISTER_NODE_TYPE(GraphBindingNode); -GraphBinding GraphBinding::Create(const SeqExpr& seq, const Array& args, size_t block_idx, - size_t binding_idx, BindingNodeKind kind) { +GraphBinding::GraphBinding(const SeqExpr& seq, const Array& args, size_t block_idx, + size_t binding_idx, BindingNodeKind kind) { ObjectPtr n = make_object(); n->seq = seq; n->args = args; n->block_idx = block_idx; n->binding_idx = binding_idx; n->kind = kind; - return GraphBinding(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(ControlFlowGraphNode); -ControlFlowGraph ControlFlowGraph::Create(const Array& bindings, - const Array>& preds, - const Array>& succs) { +ControlFlowGraph::ControlFlowGraph(const Array& bindings, + const Array>& preds, + const Array>& succs) { ObjectPtr n = make_object(); n->bindings = bindings; n->preds = preds; n->succs = succs; - return ControlFlowGraph(n); + data_ = std::move(n); } // Extracts a basic block and updates the running lists bindings, preds, and succs. @@ -64,7 +64,7 @@ size_t ExtractCFGHelper(const SeqExpr& seq, const Array& args, size_t block std::vector>* succs) { // case 1: We're past the end -> this is the block body (base case) if (block_idx == seq->blocks.size()) { - bindings->push_back(GraphBinding::Create(seq, args, block_idx, 0U, BindingNodeKind::kSeqBody)); + bindings->push_back(GraphBinding(seq, args, block_idx, 0U, BindingNodeKind::kSeqBody)); preds->push_back(current_preds); // the final binding has no successors succs->push_back({}); @@ -76,8 +76,7 @@ size_t ExtractCFGHelper(const SeqExpr& seq, const Array& args, size_t block // case 2: Ordinary binding if (!binding_value.as()) { - bindings->push_back( - GraphBinding::Create(seq, args, block_idx, binding_idx, BindingNodeKind::kBinding)); + bindings->push_back(GraphBinding(seq, args, block_idx, binding_idx, BindingNodeKind::kBinding)); size_t idx = bindings->size() - 1; preds->push_back(current_preds); // successor: the next binding (there will always be at least one binding after this, @@ -87,8 +86,7 @@ size_t ExtractCFGHelper(const SeqExpr& seq, const Array& args, size_t block // case 3: dealing with a branch auto if_node = Downcast(binding_value); // start with the cond node - bindings->push_back( - GraphBinding::Create(seq, args, block_idx, binding_idx, BindingNodeKind::kIfCond)); + bindings->push_back(GraphBinding(seq, args, block_idx, binding_idx, BindingNodeKind::kIfCond)); size_t idx = bindings->size() - 1; preds->push_back(current_preds); // there will be another successor, which we will add after recursing down the branches @@ -99,8 +97,7 @@ size_t ExtractCFGHelper(const SeqExpr& seq, const Array& args, size_t block size_t final_false_idx = ExtractCFGHelper(Downcast(if_node->false_branch), {}, 0U, 0U, {idx}, bindings, preds, succs); // now create the merge - bindings->push_back( - GraphBinding::Create(seq, {}, block_idx, binding_idx, BindingNodeKind::kIfMerge)); + bindings->push_back(GraphBinding(seq, {}, block_idx, binding_idx, BindingNodeKind::kIfMerge)); size_t merge_idx = bindings->size() - 1; preds->push_back({final_true_idx, final_false_idx}); succs->push_back({merge_idx + 1}); @@ -142,7 +139,7 @@ ControlFlowGraph ExtractCFG(const Function& func) { } succ_arr.push_back(succ_ints); } - return ControlFlowGraph::Create(Array(bindings), pred_arr, succ_arr); + return ControlFlowGraph(Array(bindings), pred_arr, succ_arr); } std::pair, Array> DataflowAnalysis( @@ -219,14 +216,13 @@ size_t GetBindingIndex(const ControlFlowGraph& cfg, const SeqExpr& seq, size_t b TVM_REGISTER_GLOBAL("relax.analysis.GraphBinding") .set_body_typed([](const SeqExpr& seq, const Array& args, size_t block_idx, size_t binding_idx, int kind) { - return GraphBinding::Create(seq, args, block_idx, binding_idx, - static_cast(kind)); + return GraphBinding(seq, args, block_idx, binding_idx, static_cast(kind)); }); TVM_REGISTER_GLOBAL("relax.analysis.ControlFlowGraph") .set_body_typed([](const Array& blocks, const Array>& preds, const Array>& succs) { - return ControlFlowGraph::Create(blocks, preds, succs); + return ControlFlowGraph(blocks, preds, succs); }); TVM_REGISTER_GLOBAL("relax.analysis.ExtractCFG").set_body_typed(ExtractCFG); From 7679535a93a6e00cfa3087da04f611812278ba61 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 29 Nov 2023 14:17:01 -0500 Subject: [PATCH 18/18] Formatting --- src/relax/analysis/liveness.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relax/analysis/liveness.cc b/src/relax/analysis/liveness.cc index e3c4f7d0f6f1..548026e3bb67 100644 --- a/src/relax/analysis/liveness.cc +++ b/src/relax/analysis/liveness.cc @@ -103,8 +103,7 @@ Array> LivenessAnalysis(const Function& func) { // The map is done for safety, since directly doing Downcast>(results.first) // would *not* check the contents of results.first. Array res_objs = Downcast>(results.first); - Array in_map = - res_objs.Map([](const ObjectRef& obj) { return Downcast(obj); }); + Array in_map = res_objs.Map([](const ObjectRef& obj) { return Downcast(obj); }); Array> ret; for (const Domain& d : in_map) {