Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 59 additions & 26 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -672,47 +672,79 @@ class EliminateIdentityRewrite : public DFPatternRewrite {
DFPattern const_;
};

/*! \brief Make two consecutive add able to be constant_folded.
* This pattern matching supports commutative property for addition.
/*! \brief Switch adjacent add-mul with constants to mul-add.
* As mul-add pattern is more friendly to FoldScaleAxis.
*/
class SimplifyConsecutiveAdd : public DFPatternRewrite {
class SwitchAddMultiply : public DFPatternRewrite {
public:
SimplifyConsecutiveAdd() {
SwitchAddMultiply() {
x_ = IsWildcard();
const1_ = IsConstant();
const2_ = IsConstant();
DFPattern add_op = IsOp("add");
pattern_ = add_op({add_op({x_, const1_}), const2_});
c1_ = IsConstant();
c2_ = IsConstant();
pattern_ = (x_ + c1_) * c2_;
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
const CallNode* call = pre.as<CallNode>();
auto x = node_map[x_][0];
auto c1 = node_map[const1_][0];
auto c2 = node_map[const2_][0];
auto c1 = node_map[c1_][0];
auto c2 = node_map[c2_][0];

auto pre_call = call;
// Find the next add call.
if (pre_call->args[1].as<ConstantNode>()) {
pre_call = pre_call->args[0].as<CallNode>();
} else {
pre_call = pre_call->args[1].as<CallNode>();
if (x.as<ConstantNode>()) {
return post;
}
// Do nothing if both inputs are not constants as they will be constant folded already.
if (pre_call->args[0].as<ConstantNode>() && pre_call->args[1].as<ConstantNode>()) {

Expr const_expr = Call(Op::Get("multiply"), {c1, c2});
IRModule const_mod = IRModule::FromExpr(const_expr);
const_mod = transform::FoldConstant()(const_mod);
GlobalVar const_main = const_mod->GetGlobalVar("main");
Expr const_val = Downcast<Function>(const_mod->functions[const_main])->body;

return Call(Op::Get("add"), {Call(Op::Get("multiply"), {x, c2}), const_val});
}

private:
DFPattern x_;
DFPattern c1_;
DFPattern c2_;
};

/*! \brief Simplify two adjacent multiply or add with constants for further constant folding.
* The pattern matching supports commutative property.
*/
class SimplifyAdjacentMultiplyOrAdd : public DFPatternRewrite {
public:
SimplifyAdjacentMultiplyOrAdd() {
x_ = IsWildcard();
c1_ = IsConstant();
c2_ = IsConstant();
pattern_ = (x_ * c1_ * c2_) || (x_ + c1_ + c2_);
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
const CallNode* call = pre.as<CallNode>();
auto x = node_map[x_][0];
auto c1 = node_map[c1_][0];
auto c2 = node_map[c2_][0];

if (x.as<ConstantNode>()) {
return post;
} else {
auto add_res = Call(call->op, {c1, c2});
return Call(call->op, {x, add_res});
}
return post;

Expr const_expr = Call(call->op, {c1, c2});
IRModule const_mod = IRModule::FromExpr(const_expr);
const_mod = transform::FoldConstant()(const_mod);
GlobalVar const_main = const_mod->GetGlobalVar("main");
Expr const_val = Downcast<Function>(const_mod->functions[const_main])->body;

return Call(call->op, {x, const_val});
}

private:
DFPattern x_;
DFPattern const1_;
DFPattern const2_;
DFPattern c1_;
DFPattern c2_;
};

/*! \brief Simplifying x/sqrt to x*sqrt */
Expand Down Expand Up @@ -800,7 +832,8 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<SimplifySameCast>();
composer.AddRewrite<SimplifyConsecutiveCast>();
composer.AddRewrite<FullElementwise>();
composer.AddRewrite<SimplifyConsecutiveAdd>();
composer.AddRewrite<SwitchAddMultiply>();
composer.AddRewrite<SimplifyAdjacentMultiplyOrAdd>();
composer.AddRewrite<SimplifyDQArgMax>();
composer.AddRewrite<SimplifyDQArgMin>();
composer.AddRewrite<SimplifyDQArgSort>();
Expand Down
109 changes: 69 additions & 40 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,46 +560,75 @@ def test_concretize_multiple():
assert tvm.ir.structural_equal(actual, expected)


def test_simplify_consecutive_add():
shape = (32, 1, 1)
c_data = np.empty(shape).astype("float32")
c1 = relay.const(c_data)
c2 = relay.const(c_data)

def before_const_right():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32")
y = relay.nn.conv2d(x, w, padding=(1, 1))
y = relay.add(y, c1)
y = relay.add(y, c2)
y = relay.nn.relu(y)
return relay.Function([x, w], y)

def before_const_left():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32")
y = relay.nn.conv2d(x, w, padding=(1, 1))
y = relay.add(c1, y)
y = relay.add(c2, y)
y = relay.nn.relu(y)
return relay.Function([x, w], y)

def expected():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32")
y = relay.nn.conv2d(x, w, padding=(1, 1))
c3 = relay.add(c1, c2)
y = relay.add(y, c3)
y = relay.nn.relu(y)
return relay.Function([x, w], y)

zr = before_const_right()
zl = before_const_left()
zzr = run_opt_pass(zr, transform.SimplifyExpr())
zzl = run_opt_pass(zl, transform.SimplifyExpr())
after = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(zzr, after)
assert tvm.ir.structural_equal(zzl, after)
def test_simplify_mul_add():
def check_simple_fold(origin_exprs, expect_expr):
for origin_expr in origin_exprs:
simple_expr = run_opt_pass(origin_expr, transform.SimplifyExpr())
assert tvm.ir.structural_equal(simple_expr, expect_expr)

n = 32
c1_val = np.random.uniform(size=n).astype("float32")
c2_val = np.random.uniform(size=n).astype("float32")
c3_val = np.random.uniform(size=n).astype("float32")

x = relay.var("x", shape=(n,), dtype="float32")
c1 = relay.const(c1_val)
c2 = relay.const(c2_val)
c3 = relay.const(c3_val)

# add-add -> add
origin_exprs = [
x + c1 + c2,
c1 + x + c2,
]
expect_expr = x + relay.const(c1_val + c2_val)
check_simple_fold(origin_exprs, expect_expr)

# mul-mul -> mul
origin_exprs = [
x * c1 * c2,
c1 * x * c2,
]
expect_expr = x * relay.const(c1_val * c2_val)
check_simple_fold(origin_exprs, expect_expr)

# add-mul -> mul-add
origin_exprs = [
(x + c1) * c2,
(c1 + x) * c2,
c2 * (x + c1),
c2 * (c1 + x),
]
expect_expr = x * c2 + relay.const(c1_val * c2_val)
check_simple_fold(origin_exprs, expect_expr)

# add-mul-add -> mul-add
origin_exprs = [
(x + c1) * c2 + c3,
(c1 + x) * c2 + c3,
c2 * (x + c1) + c3,
c2 * (c1 + x) + c3,
c3 + (x + c1) * c2,
c3 + (c1 + x) * c2,
c3 + c2 * (x + c1),
c3 + c2 * (c1 + x),
]
expect_expr = x * c2 + relay.const(c1_val * c2_val + c3_val)
check_simple_fold(origin_exprs, expect_expr)

# mul-add-mul -> mul-add
origin_exprs = [
(x * c1 + c2) * c3,
(c1 * x + c2) * c3,
(c2 + x * c1) * c3,
(c2 + c1 * x) * c3,
c3 * (x * c1 + c2),
c3 * (c1 * x + c2),
c3 * (c2 + x * c1),
c3 * (c2 + c1 * x),
]
expect_expr = x * relay.const(c1_val * c3_val) + relay.const(c2_val * c3_val)
check_simple_fold(origin_exprs, expect_expr)


def test_simplify_rsqrt():
Expand Down