diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 5a21f76b0b4e..ac481d5006ba 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -137,6 +137,15 @@ TVM_DLL Pass Normalize();
  */
 TVM_DLL Pass CanonicalizeBindings();
 
+/*!
+ * \brief Eliminate common subexpressions within dataflow blocks.
+ * \return The pass that eliminates common subexpressions.
+ *
+ * \note For functions local to dataflow blocks, this pass performs
+ * CSE *within* those functions.
+ */
+TVM_DLL Pass EliminateCommonSubexpr();
+
 /*!
  * \brief Bind params of function of the module to constant tensors.
  *
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 95f81f7e6cde..f243dc16cf77 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -100,6 +100,20 @@ def CanonicalizeBindings() -> tvm.ir.transform.Pass:
     return _ffi_api.CanonicalizeBindings()  # type: ignore
 
 
+def EliminateCommonSubexpr() -> DataflowBlockPass:
+    """Eliminate common subexpressions within dataflow blocks.
+
+    Note: For functions local to dataflow blocks, this pass performs
+    CSE *within* those functions.
+
+    Returns
+    -------
+    ret : DataflowBlockPass
+        The registered pass that eliminates common subexpressions.
+    """
+    return _ffi_api.EliminateCommonSubexpr()  # type: ignore
+
+
 def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
     """Convert all reshape-like call_tir to VM reshape operator call.
     The VM reshape operator calls will be further lowered to a CreateView
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
new file mode 100644
index 000000000000..9c9252ddfa72
--- /dev/null
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/eliminate_common_subexpr.cc
+ * \brief Eliminate common subexpression pass.
+ *
+ * Currently it removes common subexpressions within a DataflowBlock.
+ */
+#include
+#include
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Counts the occurrences of each subexpression in the bindings of a DataflowBlock.
+ *
+ * Vars, op nodes and scalar constants are not counted (see VisitExpr). Bodies of inner
+ * functions and expressions inside struct info fields are not traversed.
+ */
+class SubexprCounter : public ExprVisitor {
+ public:
+  // overriding VisitExpr ensures we do this for every subexpression
+  void VisitExpr(const Expr& e) override {
+    // Cases we ignore because we will not substitute them:
+    // 1. Vars of all kinds
+    // 2. Op nodes (nothing we can do)
+    // 3. Scalar constants (not much benefit from binding to a var)
+    if (!(e->IsInstance() || e->IsInstance() ||
+          e->IsInstance() || e->IsInstance() ||
+          (e.as() && (e.as()->is_scalar())))) {
+      int count = 0;
+      if (count_map_.count(e)) {
+        count = count_map_.at(e);
+      }
+      count_map_[e] = count + 1;
+    }
+    ExprVisitor::VisitExpr(e);
+  }
+
+  // do not visit inner functions: we will do CSE within those
+  void VisitExpr_(const FunctionNode* func) override {}
+
+  // we are not going to do replacements inside struct info to avoid binding lots of reused shapes
+  void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
+
+  /*!
+   * \brief Visit every binding of the block and return the accumulated counts.
+   * \param df_block The dataflow block whose bindings are scanned.
+   * \return A copy of the map from subexpression to its number of occurrences.
+   * \note count_map_ is never cleared, so counts accumulate across calls on one instance.
+   */
+  std::unordered_map Count(
+      const DataflowBlock& df_block) {
+    for (auto binding : df_block->bindings) {
+      VisitBinding(binding);
+    }
+    return count_map_;
+  }
+
+ private:
+  std::unordered_map count_map_;
+};
+
+// forward declaration (used by CommonSubexprEliminator on the blocks of inner functions)
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock&);
+
+/*!
+ * \brief Rewrites a DataflowBlock so that each subexpression counted more than once is
+ * computed by a single binding and referenced through that binding's var elsewhere.
+ *
+ * \note Only a reference to the count map is stored; the map must outlive this object.
+ */
+class CommonSubexprEliminator : public ExprMutator {
+ public:
+  explicit CommonSubexprEliminator(
+      const std::unordered_map& count_map)
+      : count_map_(count_map) {}
+
+  // overriding here ensures we visit every subexpression
+  Expr VisitExpr(const Expr& e) override {
+    if (count_map_.count(e) && count_map_.at(e) > 1) {
+      // if we already have a mapping for it, get it
+      if (replacements_.count(e)) {
+        return replacements_.at(e);
+      }
+      // Otherwise, insert a new binding for the current expression.
+      // Visit before emitting to do inner replacements
+      Expr new_e = ExprMutator::VisitExpr(e);
+      Var v = builder_->Emit(new_e);
+      replacements_[e] = v;
+      return v;
+    }
+    return ExprMutator::VisitExpr(e);
+  }
+
+  // we are not going to do replacements inside struct info to avoid binding lots of reused shapes
+  StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override {
+    return struct_info;
+  }
+
+  Expr VisitExpr_(const FunctionNode* func) override {
+    // for an inner function, we will do CSE on its body
+    Expr new_body = ExprMutator::VisitExpr(func->body);
+    if (new_body.same_as(func->body)) {
+      return GetRef(func);
+    }
+    return Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span);
+  }
+
+  // this should happen only for the inner function case
+  Expr VisitExpr_(const SeqExprNode* seq) override {
+    bool all_unchanged = true;
+    Array new_blocks;
+    // apply CSE within dataflow blocks only
+    for (auto block : seq->blocks) {
+      if (const DataflowBlockNode* df_block = block.as()) {
+        auto new_df_block = EliminateCommonSubexpr(GetRef(df_block));
+        if (!new_df_block.same_as(block)) {
+          new_blocks.push_back(new_df_block);
+          all_unchanged = false;
+          continue;
+        }
+      }
+      new_blocks.push_back(block);
+    }
+
+    if (all_unchanged) {
+      return GetRef(seq);
+    }
+    // do not visit the body
+    return SeqExpr(new_blocks, seq->body, seq->span);
+  }
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // no need to visit var def because the struct info isn't going to change
+    Expr new_value = RegisterBoundValue(binding->var, binding->value);
+
+    if (new_value.same_as(binding->value)) {
+      builder_->EmitNormalized(GetRef(binding));
+    } else {
+      // no need to renormalize new_value because all replacements are with vars
+      builder_->EmitNormalized(VarBinding(binding->var, new_value, binding->span));
+    }
+  }
+
+  void VisitBinding_(const MatchCastNode* binding) override {
+    // no need to visit var def because the struct info isn't going to change
+    Expr new_value = RegisterBoundValue(binding->var, binding->value);
+
+    // re-emit old binding if nothing changes
+    if (new_value.same_as(binding->value)) {
+      builder_->EmitNormalized(GetRef(binding));
+    } else {
+      // no need to renormalize new_value because all replacements are with vars
+      builder_->EmitNormalized(
+          MatchCast(binding->var, new_value, binding->struct_info, binding->span));
+    }
+  }
+
+ private:
+  /*!
+   * \brief Visit the value of a binding, reusing the binding's own var as the replacement
+   * when the value is a repeated subexpression that has not been given a replacement yet.
+   * \param var The var bound by the binding being processed.
+   * \param bound_value The value bound to var.
+   * \return The visited value (the replacement var if one was already registered).
+   */
+  Expr RegisterBoundValue(Var var, Expr bound_value) {
+    // special case: if we are processing a binding
+    // and this is the first time we've encountered it,
+    // we will use the binding's var for the mapping
+    bool newly_replaced = false;
+    if (count_map_.count(bound_value) && count_map_.at(bound_value) > 1 &&
+        !replacements_.count(bound_value)) {
+      replacements_[bound_value] = var;
+      newly_replaced = true;
+    }
+
+    if (newly_replaced) {
+      // If we've just added the mapping, using the overridden visitor will
+      // just return the var, which we don't want, so we will use
+      // the superclass VisitExpr to do inner substitutions
+      return ExprMutator::VisitExpr(bound_value);
+    }
+    return VisitExpr(bound_value);
+  }
+
+  const std::unordered_map& count_map_;
+  std::unordered_map replacements_;
+};
+
+/*!
+ * \brief Run CSE on one dataflow block: count its subexpressions, then rewrite it.
+ * \param df_block The dataflow block to process.
+ * \return The block produced by CommonSubexprEliminator::VisitBindingBlock.
+ */
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block) {
+  SubexprCounter counter;
+  auto count_map = counter.Count(df_block);
+  CommonSubexprEliminator eliminator(count_map);
+  return Downcast(eliminator.VisitBindingBlock(df_block));
+}
+
+namespace transform {
+
+Pass EliminateCommonSubexpr() {
+  // the IRModule and PassContext arguments are not used by this pass
+  runtime::TypedPackedFunc pass_func =
+      [=](DataflowBlock df_block, IRModule m, PassContext pc) {
+        return Downcast(EliminateCommonSubexpr(df_block));
+      };
+  return CreateDataflowBlockPass(pass_func, 1, "EliminateCommonSubexpr", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr")
+    .set_body_typed(EliminateCommonSubexpr);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git 
a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
new file mode 100644
index 000000000000..4ee9653ead39
--- /dev/null
+++ b/tests/python/relax/test_transform_cse.py
@@ -0,0 +1,191 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test eliminate common subexpr pass"""
+import tvm
+import tvm.testing
+from tvm.relax.transform import EliminateCommonSubexpr
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+import numpy as np
+
+
+def verify(input, expected):
+    """Apply EliminateCommonSubexpr to `input`; assert structural equality with `expected`."""
+    tvm.ir.assert_structural_equal(EliminateCommonSubexpr()(input), expected)
+
+
+def test_simple():
+    """A repeated R.add(x, y) binding is expected to become an alias of the first one."""
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.add(x, y)
+                lv1 = R.add(x, y)
+                gv = R.multiply(lv0, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.add(x, y)
+                # can combine with canonicalizing bindings
+                # and getting rid of unused bindings to eliminate this line too
+                lv1 = lv0
+                gv = R.multiply(lv0, lv1)
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_constants():
+    """Repeated scalar constants stay inline; a repeated (2, 2) constant is bound once."""
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+            with R.dataflow():
+                # we are not going to bind the constant 1 to a var
+                lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+                # we expect to bind the repeated large constants
+                lv1 = R.add(
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                )
+                gv = (lv0, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+            with R.dataflow():
+                lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+                lv1 = R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32")))
+                lv2 = R.add(lv1, lv1)
+                gv = (lv0, lv2)
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_repeated_inner_tuples():
+    """Nested tuples that repeat are expected to be bound to vars and reused."""
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # repeated units: (x, x), (x, (x, x)), ((x, x), (x, (x, x)))
+                tup = (((x, x), (x, (x, x))), ((x, x), (x, (x, x))), (x, (x, x)))
+                gv = tup[0][0][1]
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                t1 = (x, x)
+                t2 = (x, t1)
+                t3 = (t1, t2)
+                t4 = (t3, t3, t2)
+                gv = t4[0][0][1]
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_inner_function():
+    """CSE applies in the dataflow blocks of the local function and of its enclosing function."""
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # we are going to do CSE inside the local function
+                @R.function
+                def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+                    # not in dataflow: should not be touched
+                    z = R.add(R.add(y, y), R.add(y, y))
+                    with R.dataflow():
+                        # writing this out in ANF to illustrate why CSE behaves as it does
+                        # result of ANF transforming R.add(R.add(y, y), R.add(y, y))
+                        lv0 = R.add(y, y)
+                        lv1 = R.add(y, y)
+                        lv2 = R.add(lv0, lv1)
+                        gv = lv2
+                        R.output(gv)
+                    return R.add(z, gv)
+
+                # also making the ANF explicit to better illustrate the result of CSE
+                # result of ANF transforming R.add(R.add(bar(x), bar(x)), R.add(bar(x), bar(x)))
+                lv0 = bar(x)
+                lv1 = bar(x)
+                lv2 = R.add(lv0, lv1)
+                lv3 = bar(x)
+                lv4 = bar(x)
+                lv5 = R.add(lv3, lv4)
+                lv6 = R.add(lv2, lv5)
+                gv = lv6
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+
+                @R.function
+                def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+                    z = R.add(R.add(y, y), R.add(y, y))
+                    with R.dataflow():
+                        lv0 = R.add(y, y)
+                        lv1 = lv0
+                        lv2 = R.add(lv0, lv1)
+                        gv = lv2
+                        R.output(gv)
+                    return R.add(z, gv)
+
+                # can further clean this up
+                # using canonicalize bindings, eliminate unused bindings, and CSE again
+                lv0 = bar(x)
+                lv1 = lv0
+                lv2 = R.add(lv0, lv1)
+                lv3 = lv0
+                lv4 = lv0
+                lv5 = R.add(lv3, lv4)
+                lv6 = R.add(lv2, lv5)
+                gv = lv6
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()