From 1d602b39301889a4ee43a8a174124a8e8298fab7 Mon Sep 17 00:00:00 2001 From: Prakalp Srivastava Date: Mon, 20 Mar 2023 10:43:19 -0700 Subject: [PATCH 1/3] [Unity][Pass] Add pass for CSE within dataflow --- include/tvm/relax/transform.h | 7 + python/tvm/relax/transform/transform.py | 17 ++ .../transform/eliminate_common_subexpr.cc | 179 ++++++++++++++++++ tests/python/relax/test_transform_cse.py | 66 +++++++ 4 files changed, 269 insertions(+) create mode 100644 src/relax/transform/eliminate_common_subexpr.cc create mode 100644 tests/python/relax/test_transform_cse.py diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 5a21f76b0b4e..2edcab454d1f 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -137,6 +137,13 @@ TVM_DLL Pass Normalize(); */ TVM_DLL Pass CanonicalizeBindings(); +/*! + * Eliminate common subexpressions within dataflow blocks. + * \param fskip The callback function that decides whether an expression should be skipped. + * \return The pass that eliminates common subexpressions. + */ +TVM_DLL Pass EliminateCommonSubexpr(runtime::TypedPackedFunc fskip); + /*! * \brief Bind params of function of the module to constant tensors. * diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 95f81f7e6cde..c179ad08ff88 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -100,6 +100,23 @@ def CanonicalizeBindings() -> tvm.ir.transform.Pass: return _ffi_api.CanonicalizeBindings() # type: ignore +def EliminateCommonSubexpr(fskip=None): + """Eliminate common subexpressions within dataflow blocks. + + Parameters + ---------- + fskip: Callable + The callback function that decides whether an expression should be + skipped. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass that eliminates common subexpressions. 
+ """ + return _ffi_api.EliminateCommonSubexpr(fskip) + + def RewriteDataflowReshape() -> tvm.ir.transform.Pass: """Convert all reshape-like call_tir to VM reshape operator call. The VM reshape operator calls will be further lowered to a CreateView diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc new file mode 100644 index 000000000000..ad4f1626a098 --- /dev/null +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/transform/eliminate_common_subexpr.cc + * \brief Eliminate common subexpression pass. + * + * Currently it removes common subexpressions within a DataflowBlock. + */ +#include +#include + +namespace tvm { +namespace relax { + +/*! + * \brief Check if two expressions are equal scalars. + * \param a The expression to be checked. + * \param b The expression to be checked + * \return Whether two expressions are equal scalars. 
+ */ +static bool IsEqualScalar(const Expr& a, const Expr& b) { + const auto* constant_a = a.as(); + const auto* constant_b = b.as(); + if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) { + return false; + } + return tvm::StructuralEqual()(a, b); +} + +class CommonSubexprEliminator : public ExprMutator { + public: + explicit CommonSubexprEliminator(runtime::TypedPackedFunc fskip) : fskip_(fskip) {} + + private: + void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { + auto post = VisitExpr(GetRef(call_node)); + auto new_val = Rewrite_(call_node, post); + return ExprMutator::VisitBinding_(binding, new_val.as()); + } + + void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final { + auto post = VisitExpr(GetRef(val)); + return ExprMutator::VisitBinding_(binding, val); + } + + void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { + auto post = VisitExpr(GetRef(val)); + auto new_val = Rewrite_(val, post); + return ExprMutator::VisitBinding_(binding, new_val.as()); + } + + private: + Expr Rewrite_(const CallNode* call, const Expr& post) { + Expr new_expr = post; + const CallNode* new_call = new_expr.as(); + ICHECK(new_call); + const OpNode* op = new_call->op.as(); + StructuralEqual attrs_equal; + + if (new_call->args.size() == 0 || op == nullptr) { + return new_expr; + } + if (fskip_ != nullptr && fskip_(new_expr)) { + return new_expr; + } + + auto it = expr_map_.find(new_call->op); + if (it != expr_map_.end()) { + for (const Expr& candidate_expr : it->second) { + if (const CallNode* candidate = candidate_expr.as()) { + bool is_equivalent = true; + if (!attrs_equal(new_call->attrs, candidate->attrs)) { + continue; + } + for (size_t i = 0; i < new_call->args.size(); i++) { + if (!IsEquivalent(new_call->args[i], candidate->args[i])) { + is_equivalent = false; + break; + } + } + if (!is_equivalent) continue; + return GetRef(candidate); + } + } + } + 
expr_map_[new_call->op].push_back(new_expr); + return new_expr; + } + + Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) { + Expr new_expr = post; + const TupleGetItemNode* new_tuple_item = new_expr.as(); + ICHECK(new_tuple_item); + + if (fskip_ != nullptr && fskip_(new_expr)) { + return new_expr; + } + + auto it = expr_map_.find(new_tuple_item->tuple); + if (it != expr_map_.end()) { + for (const Expr& candidate_expr : it->second) { + if (const TupleGetItemNode* candidate = candidate_expr.as()) { + if (new_tuple_item->index == candidate->index) { + return GetRef(candidate); + } + } + } + } + expr_map_[new_tuple_item->tuple].push_back(new_expr); + return new_expr; + } + + bool IsEquivalent(const Expr& arg, const Expr& candidate_arg) { + if (arg->IsInstance() && candidate_arg->IsInstance()) { + const TupleNode* arg_node = arg.as(); + const TupleNode* candidate_arg_node = candidate_arg.as(); + + if (arg_node->fields.size() != candidate_arg_node->fields.size()) { + return false; + } + + for (size_t i = 0; i < arg_node->fields.size(); i++) { + if (!arg_node->fields[i].same_as(candidate_arg_node->fields[i]) && + !IsEqualScalar(arg_node->fields[i], candidate_arg_node->fields[i])) { + return false; + } + } + } else { + if (!arg.same_as(candidate_arg) && !IsEqualScalar(arg, candidate_arg)) { + return false; + } + } + + return true; + } + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> expr_map_; + runtime::TypedPackedFunc fskip_; +}; + +DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block, PackedFunc fskip) { + CommonSubexprEliminator mutator(fskip); + return Downcast(mutator.VisitBindingBlock(df_block)); +} + +namespace transform { + +Pass EliminateCommonSubexpr(runtime::TypedPackedFunc fskip) { + runtime::TypedPackedFunc pass_func = + [=](DataflowBlock df_block, IRModule m, PassContext pc) { + return Downcast(EliminateCommonSubexpr(df_block, fskip)); + }; + return CreateDataflowBlockPass(pass_func, 1, "EliminateCommonSubexpr", {}); +} + 
+TVM_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr") + .set_body_typed(EliminateCommonSubexpr); + +} // namespace transform + +} // namespace relax +} // namespace tvm \ No newline at end of file diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py new file mode 100644 index 000000000000..05904ae79f25 --- /dev/null +++ b/tests/python/relax/test_transform_cse.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Test eliminate common subexpr pass""" +import tvm +import tvm.testing +from tvm.relax.transform import EliminateCommonSubexpr +from tvm.script.parser import ir as I, relax as R, tir as T + + +def verify(input, expected): + tvm.ir.assert_structural_equal(EliminateCommonSubexpr()(input), expected) + + +def test_simple(): + @I.ir_module + class Before: + @R.function + def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + lv0 = R.add(x, y) + lv1 = R.add(x, y) + gv = R.multiply(lv0, lv1) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + lv0 = R.add(x, y) + gv = R.multiply(lv0, lv0) + R.output(gv) + return gv + + verify(Before, Expected) + + +def test_skip_callback(): + pass + + +def test_tuple_get_time(): + pass + + +def test_tuple_arg(): + pass + + +if __name__ == "__main__": + tvm.testing.main() From 2d322c151fffb68e98c5a630d66d6acfc6d289aa Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 21 Mar 2023 18:39:48 -0400 Subject: [PATCH 2/3] Fill in CSE definition and test cases --- include/tvm/relax/transform.h | 6 +- python/tvm/relax/transform/transform.py | 11 +- .../transform/eliminate_common_subexpr.cc | 234 ++++++++++-------- tests/python/relax/test_transform_cse.py | 134 +++++++++- 4 files changed, 267 insertions(+), 118 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 2edcab454d1f..ac481d5006ba 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -139,10 +139,12 @@ TVM_DLL Pass CanonicalizeBindings(); /*! * Eliminate common subexpressions within dataflow blocks. - * \param fskip The callback function that decides whether an expression should be skipped. * \return The pass that eliminates common subexpressions. 
+ * + * \note For functions local to dataflow blocks, this pass performs + * CSE *within* those functions. */ -TVM_DLL Pass EliminateCommonSubexpr(runtime::TypedPackedFunc fskip); +TVM_DLL Pass EliminateCommonSubexpr(); /*! * \brief Bind params of function of the module to constant tensors. diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index c179ad08ff88..f243dc16cf77 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -100,21 +100,18 @@ def CanonicalizeBindings() -> tvm.ir.transform.Pass: return _ffi_api.CanonicalizeBindings() # type: ignore -def EliminateCommonSubexpr(fskip=None): +def EliminateCommonSubexpr() -> DataflowBlockPass: """Eliminate common subexpressions within dataflow blocks. - Parameters - ---------- - fskip: Callable - The callback function that decides whether an expression should be - skipped. + Note: For functions local to dataflow blocks, this pass performs + CSE *within* those functions Returns ------- ret : tvm.transform.Pass The registered pass that eliminates common subexpressions. """ - return _ffi_api.EliminateCommonSubexpr(fskip) + return _ffi_api.EliminateCommonSubexpr() # type: ignore def RewriteDataflowReshape() -> tvm.ir.transform.Pass: diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index ad4f1626a098..f5f932c0b603 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -30,142 +30,172 @@ namespace tvm { namespace relax { -/*! - * \brief Check if two expressions are equal scalars. - * \param a The expression to be checked. - * \param b The expression to be checked - * \return Whether two expressions are equal scalars. 
- */ -static bool IsEqualScalar(const Expr& a, const Expr& b) { - const auto* constant_a = a.as(); - const auto* constant_b = b.as(); - if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) { - return false; - } - return tvm::StructuralEqual()(a, b); -} - -class CommonSubexprEliminator : public ExprMutator { +class SubexprCounter : public ExprVisitor { public: - explicit CommonSubexprEliminator(runtime::TypedPackedFunc fskip) : fskip_(fskip) {} - - private: - void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { - auto post = VisitExpr(GetRef(call_node)); - auto new_val = Rewrite_(call_node, post); - return ExprMutator::VisitBinding_(binding, new_val.as()); + // overriding VisitExpr ensures we do this for every subexpression + void VisitExpr(const Expr& e) override { + // Cases we ignore because we will not substitute them: + // 1. Vars of all kinds + // 2. Op nodes (nothing we can do) + // 3. Scalar constants (not much benefit from binding to a var) + if (!(e->IsInstance() || e->IsInstance() || + e->IsInstance() || e->IsInstance() || + (e.as() && (e.as()->is_scalar())))) { + int count = 0; + if (count_map_.count(e)) { + count = count_map_.at(e); + } + count_map_[e] = count + 1; + } + ExprVisitor::VisitExpr(e); } - void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) final { - auto post = VisitExpr(GetRef(val)); - return ExprMutator::VisitBinding_(binding, val); - } + // do not visit inner functions: we will do CSE within those + void VisitExpr_(const FunctionNode* func) override {} + + // we are not going to do replacements inside struct info to avoid binding lots of reused shapes + void VisitExprDepStructInfoField(const StructInfo& struct_info) override {} - void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* val) final { - auto post = VisitExpr(GetRef(val)); - auto new_val = Rewrite_(val, post); - return ExprMutator::VisitBinding_(binding, new_val.as()); + 
std::unordered_map Count( + const DataflowBlock& df_block) { + for (auto binding : df_block->bindings) { + VisitBinding(binding); + } + return count_map_; } private: - Expr Rewrite_(const CallNode* call, const Expr& post) { - Expr new_expr = post; - const CallNode* new_call = new_expr.as(); - ICHECK(new_call); - const OpNode* op = new_call->op.as(); - StructuralEqual attrs_equal; - - if (new_call->args.size() == 0 || op == nullptr) { - return new_expr; - } - if (fskip_ != nullptr && fskip_(new_expr)) { - return new_expr; - } + std::unordered_map count_map_; +}; - auto it = expr_map_.find(new_call->op); - if (it != expr_map_.end()) { - for (const Expr& candidate_expr : it->second) { - if (const CallNode* candidate = candidate_expr.as()) { - bool is_equivalent = true; - if (!attrs_equal(new_call->attrs, candidate->attrs)) { - continue; - } - for (size_t i = 0; i < new_call->args.size(); i++) { - if (!IsEquivalent(new_call->args[i], candidate->args[i])) { - is_equivalent = false; - break; - } - } - if (!is_equivalent) continue; - return GetRef(candidate); - } +// forward declaration +DataflowBlock EliminateCommonSubexpr(const DataflowBlock&); + +class CommonSubexprEliminator : public ExprMutator { + public: + explicit CommonSubexprEliminator( + const std::unordered_map& count_map) + : count_map_(count_map) {} + + // overriding here ensures we visit every subexpression + Expr VisitExpr(const Expr& e) override { + if (count_map_.count(e) && count_map_.at(e) > 1) { + // if we already have a mapping for it, get it + if (replacements_.count(e)) { + return replacements_.at(e); } + // Otherwise, insert a new binding for the current expression. 
+ // Visit before emitting to do inner replacements + Expr new_e = ExprMutator::VisitExpr(e); + Var v = builder_->Emit(new_e); + replacements_[e] = v; + return v; } - expr_map_[new_call->op].push_back(new_expr); - return new_expr; + return ExprMutator::VisitExpr(e); } - Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) { - Expr new_expr = post; - const TupleGetItemNode* new_tuple_item = new_expr.as(); - ICHECK(new_tuple_item); + // we are not going to do replacements inside struct info to avoid binding lots of reused shapes + StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override { + return struct_info; + } - if (fskip_ != nullptr && fskip_(new_expr)) { - return new_expr; + Expr VisitExpr_(const FunctionNode* func) override { + // for an inner function, we will do CSE on its body + Expr new_body = ExprMutator::VisitExpr(func->body); + if (new_body.same_as(func->body)) { + return GetRef(func); } + return Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span); + } - auto it = expr_map_.find(new_tuple_item->tuple); - if (it != expr_map_.end()) { - for (const Expr& candidate_expr : it->second) { - if (const TupleGetItemNode* candidate = candidate_expr.as()) { - if (new_tuple_item->index == candidate->index) { - return GetRef(candidate); - } + // this should happen only for the inner function case + Expr VisitExpr_(const SeqExprNode* seq) override { + bool all_unchanged = true; + Array new_blocks; + // apply CSE within dataflow blocks only + for (auto block : seq->blocks) { + if (const DataflowBlockNode* df_block = block.as()) { + auto new_df_block = EliminateCommonSubexpr(GetRef(df_block)); + if (!new_df_block.same_as(block)) { + new_blocks.push_back(new_df_block); + all_unchanged = false; + continue; } } + new_blocks.push_back(block); + } + + if (all_unchanged) { + return GetRef(seq); } - expr_map_[new_tuple_item->tuple].push_back(new_expr); - return new_expr; + // do not visit the body + return 
SeqExpr(new_blocks, seq->body, seq->span); } - bool IsEquivalent(const Expr& arg, const Expr& candidate_arg) { - if (arg->IsInstance() && candidate_arg->IsInstance()) { - const TupleNode* arg_node = arg.as(); - const TupleNode* candidate_arg_node = candidate_arg.as(); + void VisitBinding_(const VarBindingNode* binding) override { + // no need to visit var def because the struct info isn't going to change + Expr new_value = RegisterBoundValue(binding->var, binding->value); - if (arg_node->fields.size() != candidate_arg_node->fields.size()) { - return false; - } + if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); + } else { + // no need to renormalize new_value because all replacements are with vars + builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span)); + } + } - for (size_t i = 0; i < arg_node->fields.size(); i++) { - if (!arg_node->fields[i].same_as(candidate_arg_node->fields[i]) && - !IsEqualScalar(arg_node->fields[i], candidate_arg_node->fields[i])) { - return false; - } - } + void VisitBinding_(const MatchCastNode* binding) override { + // no need to visit var def because the struct info isn't going to change + Expr new_value = RegisterBoundValue(binding->var, binding->value); + + // re-emit old binding if nothing changes + if (new_value.same_as(binding->value)) { + builder_->EmitNormalized(GetRef(binding)); } else { - if (!arg.same_as(candidate_arg) && !IsEqualScalar(arg, candidate_arg)) { - return false; - } + // no need to renormalize new_value because all replacements are with vars + builder_->EmitNormalized( + MatchCast(binding->var, new_value, binding->struct_info, binding->span)); } + } - return true; + private: + Expr RegisterBoundValue(Var var, Expr bound_value) { + // special case: if we are processing a binding + // and this is the first time we've encountered it, + // we will use the binding's var for the mapping + bool newly_replaced = false; + if (count_map_.count(bound_value) && 
count_map_.at(bound_value) > 1 && + !replacements_.count(bound_value)) { + replacements_[bound_value] = var; + newly_replaced = true; + } + + if (newly_replaced) { + // If we've just added the mapping, using the overridden visitor will + // just return the var, which we don't want, so we will use + // the superclass VisitExpr to do inner substitutions + return ExprMutator::VisitExpr(bound_value); + } + return VisitExpr(bound_value); } - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> expr_map_; - runtime::TypedPackedFunc fskip_; + + const std::unordered_map& count_map_; + std::unordered_map replacements_; }; -DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block, PackedFunc fskip) { - CommonSubexprEliminator mutator(fskip); - return Downcast(mutator.VisitBindingBlock(df_block)); +DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block) { + SubexprCounter counter; + auto count_map = counter.Count(df_block); + CommonSubexprEliminator eliminator(count_map); + return Downcast(eliminator.VisitBindingBlock(df_block)); } namespace transform { -Pass EliminateCommonSubexpr(runtime::TypedPackedFunc fskip) { +Pass EliminateCommonSubexpr() { runtime::TypedPackedFunc pass_func = [=](DataflowBlock df_block, IRModule m, PassContext pc) { - return Downcast(EliminateCommonSubexpr(df_block, fskip)); + return Downcast(EliminateCommonSubexpr(df_block)); }; return CreateDataflowBlockPass(pass_func, 1, "EliminateCommonSubexpr", {}); } diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py index 05904ae79f25..4ee9653ead39 100644 --- a/tests/python/relax/test_transform_cse.py +++ b/tests/python/relax/test_transform_cse.py @@ -20,6 +20,8 @@ from tvm.relax.transform import EliminateCommonSubexpr from tvm.script.parser import ir as I, relax as R, tir as T +import numpy as np + def verify(input, expected): tvm.ir.assert_structural_equal(EliminateCommonSubexpr()(input), expected) @@ -43,23 +45,141 @@ class Expected: def 
foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")): with R.dataflow(): lv0 = R.add(x, y) - gv = R.multiply(lv0, lv0) + # can combine with canonicalizing bindings + # and getting rid of unused bindings to eliminate this line too + lv1 = lv0 + gv = R.multiply(lv0, lv1) + R.output(gv) + return gv + + verify(Before, Expected) + + +def test_constants(): + @I.ir_module + class Before: + @R.function + def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")): + with R.dataflow(): + # we are not going to bind the constant 1 to a var + lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32")) + # we expect to bind the repeated large constants + lv1 = R.add( + R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))), + R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))), + ) + gv = (lv0, lv1) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")): + with R.dataflow(): + lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32")) + lv1 = R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))) + lv2 = R.add(lv1, lv1) + gv = (lv0, lv2) + R.output(gv) + return gv + + verify(Before, Expected) + + +def test_repeated_inner_tuples(): + @I.ir_module + class Before: + @R.function + def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + # repeated units: (x, x), (x, (x, x)), ((x, x), (x, (x, x))) + tup = (((x, x), (x, (x, x))), ((x, x), (x, (x, x))), (x, (x, x))) + gv = tup[0][0][1] + R.output(gv) + return gv + + @I.ir_module + class Expected: + @R.function + def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + t1 = (x, x) + t2 = (x, t1) + t3 = (t1, t2) + t4 = (t3, t3, t2) + gv = t4[0][0][1] R.output(gv) return gv verify(Before, Expected) -def test_skip_callback(): - pass +def test_inner_function(): + @I.ir_module + class Before: + 
@R.function + def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + with R.dataflow(): + # we are going to do CSE inside the local function + @R.function + def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + # not in dataflow: should not be touched + z = R.add(R.add(y, y), R.add(y, y)) + with R.dataflow(): + # writing this out in ANF to illustrate why CSE behaves as it does + # result of ANF transforming R.add(R.add(y, y), R.add(y, y)) + lv0 = R.add(y, y) + lv1 = R.add(y, y) + lv2 = R.add(lv0, lv1) + gv = lv2 + R.output(gv) + return R.add(z, gv) + + # also making the ANF explicit to better illustrate the result of CSE + # result of ANF transforming R.add(R.add(bar(x), bar(x)), R.add(bar(x), bar(x))) + lv0 = bar(x) + lv1 = bar(x) + lv2 = R.add(lv0, lv1) + lv3 = bar(x) + lv4 = bar(x) + lv5 = R.add(lv3, lv4) + lv6 = R.add(lv2, lv5) + gv = lv6 + R.output(gv) + return gv + @I.ir_module + class Expected: + @R.function + def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + with R.dataflow(): -def test_tuple_get_time(): - pass + @R.function + def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + z = R.add(R.add(y, y), R.add(y, y)) + with R.dataflow(): + lv0 = R.add(y, y) + lv1 = lv0 + lv2 = R.add(lv0, lv1) + gv = lv2 + R.output(gv) + return R.add(z, gv) + # can further clean this up + # using canonicalize bindings, eliminate unused bindings, and CSE again + lv0 = bar(x) + lv1 = lv0 + lv2 = R.add(lv0, lv1) + lv3 = lv0 + lv4 = lv0 + lv5 = R.add(lv3, lv4) + lv6 = R.add(lv2, lv5) + gv = lv6 + R.output(gv) + return gv -def test_tuple_arg(): - pass + verify(Before, Expected) if __name__ == "__main__": From 1e8232d89bd2c08cbde58ce78cee2a4f920ffb48 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 21 Mar 2023 20:03:26 -0400 Subject: [PATCH 3/3] Missing trailing newline --- src/relax/transform/eliminate_common_subexpr.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index f5f932c0b603..9c9252ddfa72 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -206,4 +206,4 @@ TVM_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr") } // namespace transform } // namespace relax -} // namespace tvm \ No newline at end of file +} // namespace tvm