diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index cf594a09a266..0dfb45577280 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -672,47 +672,79 @@ class EliminateIdentityRewrite : public DFPatternRewrite { DFPattern const_; }; -/*! \brief Make two consecutive add able to be constant_folded. - * This pattern matching supports commutative property for addition. +/*! \brief Switch adjacent add-mul with constants to mul-add. + * As mul-add pattern is more friendly to FoldScaleAxis. */ -class SimplifyConsecutiveAdd : public DFPatternRewrite { +class SwitchAddMultiply : public DFPatternRewrite { public: - SimplifyConsecutiveAdd() { + SwitchAddMultiply() { x_ = IsWildcard(); - const1_ = IsConstant(); - const2_ = IsConstant(); - DFPattern add_op = IsOp("add"); - pattern_ = add_op({add_op({x_, const1_}), const2_}); + c1_ = IsConstant(); + c2_ = IsConstant(); + pattern_ = (x_ + c1_) * c2_; } Expr Callback(const Expr& pre, const Expr& post, const Map>& node_map) const override { - const CallNode* call = pre.as(); auto x = node_map[x_][0]; - auto c1 = node_map[const1_][0]; - auto c2 = node_map[const2_][0]; + auto c1 = node_map[c1_][0]; + auto c2 = node_map[c2_][0]; - auto pre_call = call; - // Find the next add call. - if (pre_call->args[1].as()) { - pre_call = pre_call->args[0].as(); - } else { - pre_call = pre_call->args[1].as(); + if (x.as()) { + return post; } - // Do nothing if both inputs are not constants as they will be constant folded already. - if (pre_call->args[0].as() && pre_call->args[1].as()) { + + Expr const_expr = Call(Op::Get("multiply"), {c1, c2}); + IRModule const_mod = IRModule::FromExpr(const_expr); + const_mod = transform::FoldConstant()(const_mod); + GlobalVar const_main = const_mod->GetGlobalVar("main"); + Expr const_val = Downcast(const_mod->functions[const_main])->body; + + return Call(Op::Get("add"), {Call(Op::Get("multiply"), {x, c2}), const_val}); + } + + private: + DFPattern x_; + DFPattern c1_; + DFPattern c2_; +}; + +/*! \brief Simplify two adjacent multiply or add with constants for further constant folding. + * The pattern matching supports commutative property. + */ +class SimplifyAdjacentMultiplyOrAdd : public DFPatternRewrite { + public: + SimplifyAdjacentMultiplyOrAdd() { + x_ = IsWildcard(); + c1_ = IsConstant(); + c2_ = IsConstant(); + pattern_ = (x_ * c1_ * c2_) || (x_ + c1_ + c2_); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + const CallNode* call = pre.as(); + auto x = node_map[x_][0]; + auto c1 = node_map[c1_][0]; + auto c2 = node_map[c2_][0]; + + if (x.as()) { return post; - } else { - auto add_res = Call(call->op, {c1, c2}); - return Call(call->op, {x, add_res}); } - return post; + + Expr const_expr = Call(call->op, {c1, c2}); + IRModule const_mod = IRModule::FromExpr(const_expr); + const_mod = transform::FoldConstant()(const_mod); + GlobalVar const_main = const_mod->GetGlobalVar("main"); + Expr const_val = Downcast(const_mod->functions[const_main])->body; + + return Call(call->op, {x, const_val}); } private: DFPattern x_; - DFPattern const1_; - DFPattern const2_; + DFPattern c1_; + DFPattern c2_; }; /*! \brief Simplifying x/sqrt to x*sqrt */ @@ -800,7 +832,8 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); - composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index e84d238aaa75..8d5ea28ade61 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -560,46 +560,75 @@ def test_concretize_multiple(): assert tvm.ir.structural_equal(actual, expected) -def test_simplify_consecutive_add(): - shape = (32, 1, 1) - c_data = np.empty(shape).astype("float32") - c1 = relay.const(c_data) - c2 = relay.const(c_data) - - def before_const_right(): - x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") - w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") - y = relay.nn.conv2d(x, w, padding=(1, 1)) - y = relay.add(y, c1) - y = relay.add(y, c2) - y = relay.nn.relu(y) - return relay.Function([x, w], y) - - def before_const_left(): - x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") - w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") - y = relay.nn.conv2d(x, w, padding=(1, 1)) - y = relay.add(c1, y) - y = relay.add(c2, y) - y = relay.nn.relu(y) - return relay.Function([x, w], y) - - def expected(): - x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") - w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") - y = relay.nn.conv2d(x, w, padding=(1, 1)) - c3 = relay.add(c1, c2) - y = relay.add(y, c3) - y = relay.nn.relu(y) - return relay.Function([x, w], y) - - zr = before_const_right() - zl = before_const_left() - zzr = run_opt_pass(zr, transform.SimplifyExpr()) - zzl = run_opt_pass(zl, transform.SimplifyExpr()) - after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zzr, after) - assert tvm.ir.structural_equal(zzl, after) +def test_simplify_mul_add(): + def check_simple_fold(origin_exprs, expect_expr): + for origin_expr in origin_exprs: + simple_expr = run_opt_pass(origin_expr, transform.SimplifyExpr()) + assert tvm.ir.structural_equal(simple_expr, expect_expr) + + n = 32 + c1_val = np.random.uniform(size=n).astype("float32") + c2_val = np.random.uniform(size=n).astype("float32") + c3_val = np.random.uniform(size=n).astype("float32") + + x = relay.var("x", shape=(n,), dtype="float32") + c1 = relay.const(c1_val) + c2 = relay.const(c2_val) + c3 = relay.const(c3_val) + + # add-add -> add + origin_exprs = [ + x + c1 + c2, + c1 + x + c2, + ] + expect_expr = x + relay.const(c1_val + c2_val) + check_simple_fold(origin_exprs, expect_expr) + + # mul-mul -> mul + origin_exprs = [ + x * c1 * c2, + c1 * x * c2, + ] + expect_expr = x * relay.const(c1_val * c2_val) + check_simple_fold(origin_exprs, expect_expr) + + # add-mul -> mul-add + origin_exprs = [ + (x + c1) * c2, + (c1 + x) * c2, + c2 * (x + c1), + c2 * (c1 + x), + ] + expect_expr = x * c2 + relay.const(c1_val * c2_val) + check_simple_fold(origin_exprs, expect_expr) + + # add-mul-add -> mul-add + origin_exprs = [ + (x + c1) * c2 + c3, + (c1 + x) * c2 + c3, + c2 * (x + c1) + c3, + c2 * (c1 + x) + c3, + c3 + (x + c1) * c2, + c3 + (c1 + x) * c2, + c3 + c2 * (x + c1), + c3 + c2 * (c1 + x), + ] + expect_expr = x * c2 + relay.const(c1_val * c2_val + c3_val) + check_simple_fold(origin_exprs, expect_expr) + + # mul-add-mul -> mul-add + origin_exprs = [ + (x * c1 + c2) * c3, + (c1 * x + c2) * c3, + (c2 + x * c1) * c3, + (c2 + c1 * x) * c3, + c3 * (x * c1 + c2), + c3 * (c1 * x + c2), + c3 * (c2 + x * c1), + c3 * (c2 + c1 * x), + ] + expect_expr = x * relay.const(c1_val * c3_val) + relay.const(c2_val * c3_val) + check_simple_fold(origin_exprs, expect_expr) def test_simplify_rsqrt():