Skip to content

Commit 01599d1

Browse files
author
Ivy Zhang
authored
[SimplifyExpr] Simplify consecutive adds with constants (#9671)
1 parent bd361b9 commit 01599d1

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

src/relay/transforms/simplify_expr.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,49 @@ class EliminateIdentityRewrite : public DFPatternRewrite {
585585
DFPattern const_;
586586
};
587587

588+
/*! \brief Make two consecutive add able to be constant_folded.
589+
* This pattern matching supports commutative property for addition.
590+
*/
591+
class SimplifyConsecutiveAdd : public DFPatternRewrite {
592+
public:
593+
SimplifyConsecutiveAdd() {
594+
x_ = IsWildcard();
595+
const1_ = IsConstant();
596+
const2_ = IsConstant();
597+
DFPattern add_op = IsOp("add");
598+
pattern_ = add_op({add_op({x_, const1_}), const2_});
599+
}
600+
601+
Expr Callback(const Expr& pre, const Expr& post,
602+
const Map<DFPattern, Array<Expr>>& node_map) const override {
603+
const CallNode* call = pre.as<CallNode>();
604+
auto x = node_map[x_][0];
605+
auto c1 = node_map[const1_][0];
606+
auto c2 = node_map[const2_][0];
607+
608+
auto pre_call = call;
609+
// Find the next add call.
610+
if (pre_call->args[1].as<ConstantNode>()) {
611+
pre_call = pre_call->args[0].as<CallNode>();
612+
} else {
613+
pre_call = pre_call->args[1].as<CallNode>();
614+
}
615+
// Do nothing if both inputs are not constants as they will be constant folded already.
616+
if (pre_call->args[0].as<ConstantNode>() && pre_call->args[1].as<ConstantNode>()) {
617+
return post;
618+
} else {
619+
auto add_res = Call(call->op, {c1, c2});
620+
return Call(call->op, {x, add_res});
621+
}
622+
return post;
623+
}
624+
625+
private:
626+
DFPattern x_;
627+
DFPattern const1_;
628+
DFPattern const2_;
629+
};
630+
588631
Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
589632
// the rewrites will be applied in the given order, and repeated until fixed point
590633
DFPatternRewriteComposer composer;
@@ -599,6 +642,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
599642
composer.AddRewrite<SimplifyTranspose>();
600643
composer.AddRewrite<SimplifyCast>();
601644
composer.AddRewrite<FullElementwise>();
645+
composer.AddRewrite<SimplifyConsecutiveAdd>();
602646
return RewritePatterns(composer.MakeCallbacks(), expr, mod);
603647
}
604648

tests/python/relay/test_pass_simplify_expr.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,5 +512,47 @@ def test_concretize_multiple():
512512
assert tvm.ir.structural_equal(actual, expected)
513513

514514

515+
def test_simplify_consecutive_add():
516+
shape = (32, 1, 1)
517+
c_data = np.empty(shape).astype("float32")
518+
c1 = relay.const(c_data)
519+
c2 = relay.const(c_data)
520+
521+
def before_const_right():
522+
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
523+
w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32")
524+
y = relay.nn.conv2d(x, w, padding=(1, 1))
525+
y = relay.add(y, c1)
526+
y = relay.add(y, c2)
527+
y = relay.nn.relu(y)
528+
return relay.Function([x, w], y)
529+
530+
def before_const_left():
531+
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
532+
w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32")
533+
y = relay.nn.conv2d(x, w, padding=(1, 1))
534+
y = relay.add(c1, y)
535+
y = relay.add(c2, y)
536+
y = relay.nn.relu(y)
537+
return relay.Function([x, w], y)
538+
539+
def expected():
540+
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
541+
w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32")
542+
y = relay.nn.conv2d(x, w, padding=(1, 1))
543+
c3 = relay.add(c1, c2)
544+
y = relay.add(y, c3)
545+
y = relay.nn.relu(y)
546+
return relay.Function([x, w], y)
547+
548+
zr = before_const_right()
549+
zl = before_const_left()
550+
zzr = run_opt_pass(zr, transform.SimplifyExpr())
551+
zzl = run_opt_pass(zl, transform.SimplifyExpr())
552+
after = run_opt_pass(expected(), transform.InferType())
553+
assert tvm.ir.structural_equal(zzr, after)
554+
assert tvm.ir.structural_equal(zzl, after)
555+
556+
515557
if __name__ == "__main__":
516558
pytest.main([__file__])

0 commit comments

Comments
 (0)