From 18b6aa0a49c68d17b39b2921b17f9157cf3da75f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 12 Sep 2023 16:04:00 +0000 Subject: [PATCH] [Unity] Extend EliminateCommonSubexpr to operate on relax::Expr Prior to this commit, the `EliminateCommonSubexpr` utility could only apply to `relax::Function` instances. This commit extends the allowed usage to apply to any `relax::Expr` that contains variable bindings. This is only included as an internal utility within the C++ implementation, and is not currently exposed for external use. --- .../transform/eliminate_common_subexpr.cc | 43 ++++++++++--------- src/relax/transform/utils.h | 14 ++++++ 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc index 8bbb05f32797..fa90d4193337 100644 --- a/src/relax/transform/eliminate_common_subexpr.cc +++ b/src/relax/transform/eliminate_common_subexpr.cc @@ -28,6 +28,8 @@ #include #include +#include "utils.h" + namespace tvm { namespace relax { @@ -74,6 +76,12 @@ class ImpurityDetector : public ExprVisitor { class SubexprCounter : public ExprVisitor { public: + static std::unordered_map Count(const Expr& expr) { + SubexprCounter visitor; + visitor(expr); + return visitor.count_map_; + } + // overriding VisitExpr ensures we do this for every subexpression void VisitExpr(const Expr& e) override { // Cases we ignore because we will not substitute them: @@ -106,25 +114,17 @@ class SubexprCounter : public ExprVisitor { // we are not going to do replacements inside struct info to avoid binding lots of reused shapes void VisitExprDepStructInfoField(const StructInfo& struct_info) override {} - std::unordered_map Count(const Function& func) { - VisitExpr(func->body); - return count_map_; - } - private: std::unordered_map count_map_; ImpurityDetector impurity_detector_; }; -// forward declaration -Function EliminateCommonSubexpr(const Function&, bool call_only); - class CommonSubexprEliminator : public ExprMutator { public: explicit CommonSubexprEliminator( - const std::unordered_map& count_map, + std::unordered_map count_map, bool call_only = false) - : count_map_(count_map), call_only_(call_only) {} + : count_map_(std::move(count_map)), call_only_(call_only) {} // overriding here ensures we visit every subexpression Expr VisitExpr(const Expr& e) override { @@ -151,9 +151,15 @@ class CommonSubexprEliminator : public ExprMutator { return struct_info; } - Expr VisitExpr_(const FunctionNode* func) override { - // do full CSE within the function - return EliminateCommonSubexpr(GetRef(func), call_only_); + Expr VisitExpr_(const FunctionNode* op) override { + Function func = GetRef(op); + + auto cache = SubexprCounter::Count(op->body); + std::swap(cache, count_map_); + Expr output = ExprMutator::VisitExpr_(op); + std::swap(cache, count_map_); + + return output; } void VisitBinding_(const VarBindingNode* binding) override { @@ -203,17 +209,14 @@ class CommonSubexprEliminator : public ExprMutator { return VisitExpr(bound_value); } - const std::unordered_map& count_map_; + std::unordered_map count_map_; std::unordered_map replacements_; bool call_only_{false}; }; -Function EliminateCommonSubexpr(const Function& func, bool call_only) { - SubexprCounter counter; - auto count_map = counter.Count(func); - CommonSubexprEliminator eliminator(count_map, call_only); - return Function(func->params, eliminator.VisitExpr(func->body), func->ret_struct_info, - func->is_pure, func->attrs, func->span); +Expr EliminateCommonSubexpr(const Expr& expr, bool call_only) { + CommonSubexprEliminator mutator(SubexprCounter::Count(expr), call_only); + return mutator(expr); } namespace transform { diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 3d40a565bd2d..2bf4e93672f1 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -388,6 +388,20 @@ inline String GetCodegenName(const std::string& composite_name) { return composite_name.substr(0, delim_pos); } +/* \brief Eliminate common subexpressions + * + * Utility for simplifying relax expressions by removing common + * subexpressions. + * + * \param expr The expression to be updated + * + * \param call_only If true, only eliminate relax::Call nodes. If + * false, eliminate any common subexpressions. + * + * \ret The updated expression + */ +Expr EliminateCommonSubexpr(const Expr& expr, bool call_only = false); + } // namespace relax } // namespace tvm