Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions src/relax/transform/eliminate_common_subexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <tvm/relax/transform.h>
#include <tvm/relax/utils.h>

#include "utils.h"

namespace tvm {
namespace relax {

Expand Down Expand Up @@ -74,6 +76,12 @@ class ImpurityDetector : public ExprVisitor {

class SubexprCounter : public ExprVisitor {
public:
static std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Expr& expr) {
SubexprCounter visitor;
visitor(expr);
return visitor.count_map_;
}

// overriding VisitExpr ensures we do this for every subexpression
void VisitExpr(const Expr& e) override {
// Cases we ignore because we will not substitute them:
Expand Down Expand Up @@ -106,25 +114,17 @@ class SubexprCounter : public ExprVisitor {
// we are not going to do replacements inside struct info to avoid binding lots of reused shapes
void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}

std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Function& func) {
VisitExpr(func->body);
return count_map_;
}

private:
std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
ImpurityDetector impurity_detector_;
};

// forward declaration
Function EliminateCommonSubexpr(const Function&, bool call_only);

class CommonSubexprEliminator : public ExprMutator {
public:
explicit CommonSubexprEliminator(
const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map,
std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map,
bool call_only = false)
: count_map_(count_map), call_only_(call_only) {}
: count_map_(std::move(count_map)), call_only_(call_only) {}

// overriding here ensures we visit every subexpression
Expr VisitExpr(const Expr& e) override {
Expand All @@ -151,9 +151,15 @@ class CommonSubexprEliminator : public ExprMutator {
return struct_info;
}

Expr VisitExpr_(const FunctionNode* func) override {
// do full CSE within the function
return EliminateCommonSubexpr(GetRef<Function>(func), call_only_);
Expr VisitExpr_(const FunctionNode* op) override {
Function func = GetRef<Function>(op);

auto cache = SubexprCounter::Count(op->body);
std::swap(cache, count_map_);
Expr output = ExprMutator::VisitExpr_(op);
std::swap(cache, count_map_);

return output;
}

void VisitBinding_(const VarBindingNode* binding) override {
Expand Down Expand Up @@ -203,17 +209,14 @@ class CommonSubexprEliminator : public ExprMutator {
return VisitExpr(bound_value);
}

const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map_;
std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
std::unordered_map<Expr, Var, StructuralHash, StructuralEqual> replacements_;
bool call_only_{false};
};

Function EliminateCommonSubexpr(const Function& func, bool call_only) {
SubexprCounter counter;
auto count_map = counter.Count(func);
CommonSubexprEliminator eliminator(count_map, call_only);
return Function(func->params, eliminator.VisitExpr(func->body), func->ret_struct_info,
func->is_pure, func->attrs, func->span);
Expr EliminateCommonSubexpr(const Expr& expr, bool call_only) {
CommonSubexprEliminator mutator(SubexprCounter::Count(expr), call_only);
return mutator(expr);
}

namespace transform {
Expand Down
14 changes: 14 additions & 0 deletions src/relax/transform/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,20 @@ inline String GetCodegenName(const std::string& composite_name) {
return composite_name.substr(0, delim_pos);
}

/* \brief Eliminate common subexpressions
*
* Utility for simplifying relax expressions by removing common
* subexpressions.
*
* \param expr The expression to be updated
*
* \param call_only If true, only eliminate relax::Call nodes. If
* false, eliminate any common subexpressions.
*
* \ret The updated expression
*/
Expr EliminateCommonSubexpr(const Expr& expr, bool call_only = false);

} // namespace relax
} // namespace tvm

Expand Down