diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index f515ba620196..82fb73b1bd0d 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -399,18 +399,25 @@ TVM_DLL Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb);
 /*!
  * \brief Get the use-def chain of variables inside a function.
  *
- * \param fn The function to be analyzed.
- * \return A map from variable definitions to a set of uses and variables needed by return value.
+ * \param expr The expression to be analyzed.
+ *
+ * \return A tuple of variable usage and variable outputs.  The first
+ * element is a map from variable definitions to the set of downstream
+ * users of that definition.  The second element is a list of
+ * variables whose usage occurs outside of any variable binding,
+ * typically the output body of a relax::Function or a relax::SeqExpr.
  */
-std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn);
+std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Expr& expr);
 
 /*!
  * \brief Remove unused statements inside DataflowBlocks.
  *
- * \param fn The function to remove unused statements.
- * \return The function that contains no unused statements in DataflowBlock.
+ * \param expr The expression (typically a relax::Function) from which
+ * to remove unused statements.
+ *
+ * \return The updated expression with no unused statements in DataflowBlock.
  */
-TVM_DLL Function RemoveAllUnused(const Function fn);
+TVM_DLL Expr RemoveAllUnused(Expr expr);
 
 /*!
  * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
index 1c49fd581f7d..7355612b6479 100644
--- a/src/relax/analysis/udchain.cc
+++ b/src/relax/analysis/udchain.cc
@@ -40,14 +40,15 @@ class UDChain : public relax::ExprVisitor {
   // nullptr users means it is the output of the function.
   std::map<const VarNode*, std::set<const VarNode*>> to_users;
 
-  const VarNode* cur_user_;
+  const VarNode* cur_user_{nullptr};
 
   void VisitBinding_(const VarBindingNode* binding) override {
     // init
+    auto cache = cur_user_;
     cur_user_ = binding->var.get();
     this->VisitVarDef(binding->var);
     this->VisitExpr(binding->value);
-    cur_user_ = nullptr;
+    cur_user_ = cache;
   }
 
   void VisitExpr_(const VarNode* op) override { to_users[op].insert(cur_user_); }
@@ -63,7 +64,7 @@ class UDChain : public relax::ExprVisitor {
 };
 
 std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>> FunctionUseDef(
-    const Function& fn) {
+    const Expr& fn) {
   UDChain udchain;
   udchain.VisitExpr(fn);
 
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
index 4eec0310fbab..2a9537851430 100644
--- a/src/relax/ir/binding_rewrite.cc
+++ b/src/relax/ir/binding_rewrite.cc
@@ -245,38 +245,31 @@ class RemoveUnusedVars : public ExprMutator {
   RemoveUnusedVars(Map<Var, Array<Var>> users, Array<Var> fn_outputs)
       : RemoveUnusedVars(GetUnusedVars(users, fn_outputs)) {}
 
-  BindingBlock VisitBindingBlock_(const BindingBlockNode* block) override {
-    builder_->BeginBindingBlock();
-    for (Binding binding : block->bindings) {
-      bool can_remove = [&]() -> bool {
-        if (!unused_vars.count(binding->var)) {
-          return false;
-        }
-        auto var_binding = binding.as<VarBindingNode>();
-        if (!var_binding) {
-          return false;
-        }
-        return var_binding->value->IsInstance<VarNode>();
-      }();
-      if (!can_remove) {
-        VisitBinding(binding);
-      }
+  void VisitBinding_(const VarBindingNode* binding) override {
+    bool can_remove = unused_vars.count(binding->var) &&
+                      (in_dataflow_block_ || !ContainsImpureCall(binding->value));
+    if (!can_remove) {
+      ExprMutator::VisitBinding_(binding);
     }
-    return builder_->EndBlock();
   }
 
   BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override {
-    auto prev_dfb = GetRef<DataflowBlock>(block);
-    builder_->BeginDataflowBlock();
-    for (Binding binding : block->bindings) {
-      if (!unused_vars.count(binding->var) || binding.as<MatchCastNode>()) {
-        VisitBinding(binding);
-      }
+    bool capture_output = (block == caught_rewrite.get());
+
+    bool cache = in_dataflow_block_;
+    in_dataflow_block_ = true;
+    BindingBlock output = ExprMutator::VisitBindingBlock_(block);
+    in_dataflow_block_ = cache;
+
+    if (capture_output) {
+      caught_rewrite = Downcast<DataflowBlock>(output);
     }
-    auto new_dfb = builder_->EndBlock();
-    if (caught_rewrite == prev_dfb) caught_rewrite = Downcast<DataflowBlock>(new_dfb);
-    return std::move(new_dfb);
+
+    return std::move(output);
   }
+
+ private:
+  bool in_dataflow_block_{false};
 };
 
 void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) {
@@ -327,10 +320,10 @@ void DataflowBlockRewriteNode::RemoveAllUnused() {
 TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused")
     .set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); });
 
-Function RemoveAllUnused(Function fn) {
-  auto [users, outputs] = FunctionUseDef(fn);
+Expr RemoveAllUnused(Expr expr) {
+  auto [users, outputs] = FunctionUseDef(expr);
   RemoveUnusedVars remover(users, outputs);
-  return Downcast<Function>(remover.VisitExpr_(fn.get()));
+  return remover.VisitExpr(std::move(expr));
 }
 
 TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused);
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index ab2ad4fa3693..b1cc04b3acfa 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -938,7 +938,7 @@ class PatternRewriter : ExprMutator {
       params.insert(p.get());
     }
     PatternRewriter rewriter(pat, rewriter_func, params);
-    return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f)));
+    return Downcast<Function>(RemoveAllUnused(rewriter.VisitExpr(f)));
   }
 
   void VisitBinding_(const VarBindingNode* binding) final {
diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc
index 494665ec712a..6d9f25296a5a 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -137,7 +137,7 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array<runtime::String> ent
   IRModule updates;
   for (const auto& [gvar, base_func] : mod->functions) {
     if (auto opt = base_func.as<relax::Function>()) {
-      auto new_func = RemoveAllUnused(opt.value());
+      auto new_func = Downcast<Function>(RemoveAllUnused(opt.value()));
       if (!new_func.same_as(base_func)) {
         updates->Add(gvar, new_func);
       }
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index a13b7f3d9368..a07bfeb89d79 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -34,7 +34,7 @@ class ConstantFolder : public ExprMutator {
  public:
   static Function Fold(Function func, IRModule ctx_module) {
     ConstantFolder folder(std::move(ctx_module));
-    func = RemoveAllUnused(Downcast<Function>(folder(func)));
+    func = Downcast<Function>(RemoveAllUnused(folder(func)));
     return func;
   }
 
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index 4fd18386703a..024e40c9ce57 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -635,7 +635,7 @@ class GradientMutator : private ExprMutator {
     new_func = CallTIRWithGradEliminator::Transform(new_func);
 
     if (remove_all_unused) {
-      new_func = RemoveAllUnused(new_func);
+      new_func = Downcast<Function>(RemoveAllUnused(new_func));
     }
 
     // Step 5.3 mark the transformed function as public
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index d5545a0a5623..dfee0262068a 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -88,6 +88,13 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
 
 
 def test_binding_block_remove_all_unused():
+    """Remove unused dataflow bindings
+
+    Removal of unused bindings must not remove side effects.  Since
+    bindings within a dataflow block are guaranteed not to have side
+    effects, they may be removed if unused.
+    """
+
     @tvm.script.ir_module
     class IdentityUnused:
         @R.function
@@ -117,24 +124,49 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
     tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
 
 
-def test_binding_block_remove_all_unused_without_dataflow():
-    @tvm.script.ir_module
-    class IdentityUnused:
-        @R.function
-        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
-            lv0 = x
-            unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
-            unused1 = R.call_dps_packed(
-                "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
-            )
-            z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32")))
-            return z
+def test_binding_block_remove_unused_pure_without_dataflow():
+    """Remove unused pure bindings outside a dataflow block
 
-    optimized = remove_all_unused(IdentityUnused["main"])
+    Removal of unused bindings must not remove side effects.  Unused
+    bindings whose value is a pure operation
+    (e.g. `R.call_dps_packed`) may be removed, even if outside of a
+    dataflow block.
+ """ - GroundTruth = IdentityUnused + @R.function(private=True) + def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + lv0 = x + unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_dps_packed("my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")) + return x - tvm.ir.assert_structural_equal(optimized, GroundTruth["main"]) + @R.function(private=True) + def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + return x + + after = remove_all_unused(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_binding_block_keep_impure_without_dataflow(): + """Remove unused dataflow bindings + + Removal of unused bindings may not remove side effects. Unused + bindings whose value is an impure operation (e.g. `R.call_packed`) + may not be removed, as outside of a dataflow block they may + contain side effects. + """ + + @R.function(private=True) + def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + lv0 = x + y = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) + return y + + expected = before + + after = remove_all_unused(before) + tvm.ir.assert_structural_equal(expected, after) def test_binding_block_remove_all_unused_func_without_dataflow(): @@ -226,6 +258,70 @@ def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(dtype="int32", ndim=3): tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"]) +def test_remove_all_unused_from_dataflow_block(): + """Like test_chained_remove_all_unused, but on a SeqExpr""" + + @R.function + def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_dps_packed( + "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32") + ) + R.output(lv0) + return lv0 + + @R.function + def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + with R.dataflow(): + lv0 = x + R.output(lv0) + return lv0 + + after = remove_all_unused(before.body) + tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True) + + +def test_remove_all_unused_from_binding_block(): + """Like test_chained_remove_all_unused, but on a SeqExpr""" + + @R.function + def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + lv0 = x + unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_dps_packed("my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")) + return lv0 + + @R.function + def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + lv0 = x + return lv0 + + after = remove_all_unused(before.body) + tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True) + + +def test_retain_impure_calls_unused_in_binding_block(): + """An impure call may have side effects, and must be kept""" + + @R.function + def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + lv0 = x + unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_dps_packed("my_unused_call", (lv0,), R.Tensor((32, 32), dtype="float32")) + return lv0 + + @R.function + def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + lv0 = x + unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32")) + return lv0 + + after = remove_all_unused(before.body) + tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True) + + def test_name_to_binding_var_shadowing(): @R.function def main(x: 
R.Tensor((32, 32), "float32")) -> R.Tensor: diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index c2a3bd50922b..a4dffba11443 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -73,7 +73,6 @@ def before(c0: R.Tensor((16, 16), "float32")): @R.function def expected(c1: R.Tensor((16, 16), "float32")): - lv0 = c1 return c1 c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) @@ -104,7 +103,6 @@ def before(c0: R.Tensor((2, 3), "float32")): @R.function def expected(c1: R.Tensor((3, 2), "float32")): - lv0 = c1 return c1 c0_np = np.arange(2 * 3).astype("float32").reshape(2, 3) @@ -135,8 +133,6 @@ def before(c0: R.Tensor((2, 2), "float32")): @R.function def expected(c1: R.Tensor((2, 2), "float32"), c2: R.Tensor((2, 2), "float32")): - lv0 = c1 - lv1 = c2 return c2 c0_np = np.arange((2 * 2)).astype("float32").reshape(2, 2) @@ -218,7 +214,7 @@ def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)): lv2 = relax.call_tir(cls.sub, (c0, lv1), R.Tensor((16, 16), dtype="float32")) # this line can not be folded because x's shape is unknown lv3 = relax.call_tir(cls.sub, (lv2, x), R.Tensor((16, 16), dtype="float32")) - return lv3 + return (lv0, lv3) @R.function def expected( @@ -226,19 +222,15 @@ def expected( c1: R.Tensor((16, 16), "float32"), c2: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2), - ) -> R.Tensor: + ): n, m = T.int64(), T.int64() cls = Module x0 = R.match_cast(x, R.Tensor((n, m), "float32")) # this line cannot be folded because n is unknown lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), dtype="float32")) - # this line can be folded - lv1 = c1 - # this line can be folded because all inputs are const - lv2 = c2 # this line can not be folded because x's shape is unknown lv3 = relax.call_tir(cls.sub, (c2, x), R.Tensor((16, 16), dtype="float32")) - return lv3 + return (lv0, lv3) c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16) c1_np = c0_np + 1 @@ -268,7 +260,6 @@ def before(c0: R.Tensor((16, 16), "int32")): @R.function def expected(c1: R.Tensor((16, 16), "int32")): - lv0 = c1 return c1 c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16) diff --git a/tests/python/relax/test_tuning_api.py b/tests/python/relax/test_tuning_api.py index 5c2f165dc31d..082c9ce16a30 100644 --- a/tests/python/relax/test_tuning_api.py +++ b/tests/python/relax/test_tuning_api.py @@ -64,7 +64,6 @@ def before(c0: R.Tensor((16, 16), "int32")): # Expected IRModule after transformation. @R.function def expected(c1: R.Tensor((16, 16), "int32")): - lv0 = c1 return c1
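
Usage sketch (not part of the patch): a minimal example of the new behavior,
assuming `remove_all_unused` is importable from `tvm.relax.analysis`,
consistent with the "relax.analysis.remove_all_unused" global registered
above.  The names `func`, `simplified_func`, and `simplified_body` below are
hypothetical.

    import tvm
    from tvm.script import relax as R
    from tvm.relax.analysis import remove_all_unused

    @R.function(private=True)
    def func(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"):
        # Pure but unused: eligible for removal even outside a dataflow block.
        unused = R.add(x, x)
        return x

    # Before this patch, RemoveAllUnused only accepted a relax.Function.
    # Now a SeqExpr body may be passed directly as well, as the new tests do.
    simplified_func = remove_all_unused(func)
    simplified_body = remove_all_unused(func.body)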