From a396a3a4f0b5cb1abef5e4644a2f8c93af9a99be Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Thu, 27 Oct 2022 16:01:11 +0800 Subject: [PATCH 1/2] simplify adjacent muls and adds with constants --- src/relay/transforms/simplify_expr.cc | 130 ++++++++++++++---- tests/python/relay/test_pass_simplify_expr.py | 119 ++++++++++------ 2 files changed, 182 insertions(+), 67 deletions(-) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index cf594a09a266..b4a2c2c817a1 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -672,47 +672,122 @@ class EliminateIdentityRewrite : public DFPatternRewrite { DFPattern const_; }; -/*! \brief Make two consecutive add able to be constant_folded. - * This pattern matching supports commutative property for addition. +/*! + * \brief Returns whether \p expr is a ConstantNode or is a Call with + * \p IsConstantExpr args. + */ +bool IsConstantExpr(const Expr& expr) { + if (expr.as()) { + return true; + } else if (const CallNode* call = expr.as()) { + return std::all_of(call->args.begin(), call->args.end(), IsConstantExpr); + } else { + return false; + } +} + +/*! \brief Switch adjacent add-mul with constants to mul-add. + * As mul-add pattern is more friendly to FoldScaleAxis. */ -class SimplifyConsecutiveAdd : public DFPatternRewrite { +class SwitchAddMultiply : public DFPatternRewrite { public: - SimplifyConsecutiveAdd() { - x_ = IsWildcard(); - const1_ = IsConstant(); - const2_ = IsConstant(); - DFPattern add_op = IsOp("add"); - pattern_ = add_op({add_op({x_, const1_}), const2_}); + SwitchAddMultiply() { + a_ = IsWildcard(); + b_ = IsWildcard(); + c_ = IsWildcard(); + pattern_ = (a_ + b_) * c_; } Expr Callback(const Expr& pre, const Expr& post, const Map>& node_map) const override { - const CallNode* call = pre.as(); - auto x = node_map[x_][0]; - auto c1 = node_map[const1_][0]; - auto c2 = node_map[const2_][0]; + auto a = node_map[a_][0]; + auto b = node_map[b_][0]; + auto c = node_map[c_][0]; + + bool is_a_const = IsConstantExpr(a); + bool is_b_const = IsConstantExpr(b); + bool is_c_const = IsConstantExpr(c); - auto pre_call = call; - // Find the next add call. - if (pre_call->args[1].as()) { - pre_call = pre_call->args[0].as(); + if (!is_c_const) { + return post; + } + if (is_a_const && is_b_const) { + return post; + } + + Expr x, c_add, c_mul; + c_mul = c; + if (is_a_const) { + x = b; + c_add = a; + } else if (is_b_const) { + x = a; + c_add = b; } else { - pre_call = pre_call->args[1].as(); + return post; } - // Do nothing if both inputs are not constants as they will be constant folded already. - if (pre_call->args[0].as() && pre_call->args[1].as()) { + + auto bias = Call(Op::Get("multiply"), {c_add, c_mul}); + return Call(Op::Get("add"), {Call(Op::Get("multiply"), {x, c_mul}), bias}); + } + + private: + DFPattern a_; + DFPattern b_; + DFPattern c_; +}; + +/*! \brief Simplify two adjacent multiply or add with constants for further constant folding. + * The pattern matching supports commutative property. + */ +class SimplifyAdjacentMultiplyOrAdd : public DFPatternRewrite { + public: + SimplifyAdjacentMultiplyOrAdd() { + a_ = IsWildcard(); + b_ = IsWildcard(); + c_ = IsWildcard(); + pattern_ = (a_ * b_ * c_) || (a_ + b_ + c_); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + const CallNode* call = pre.as(); + auto a = node_map[a_][0]; + auto b = node_map[b_][0]; + auto c = node_map[c_][0]; + + bool is_a_const = IsConstantExpr(a); + bool is_b_const = IsConstantExpr(b); + bool is_c_const = IsConstantExpr(c); + + if (is_a_const && is_b_const && is_c_const) { return post; + } + + Expr x, c1, c2; + if (is_a_const && is_b_const) { + x = c; + c1 = a; + c2 = b; + } else if (is_a_const && is_c_const) { + x = b; + c1 = a; + c2 = c; + } else if (is_b_const && is_c_const) { + x = a; + c1 = b; + c2 = c; } else { - auto add_res = Call(call->op, {c1, c2}); - return Call(call->op, {x, add_res}); + return post; } - return post; + auto const_res = Call(call->op, {c1, c2}); + return Call(call->op, {x, const_res}); } private: - DFPattern x_; - DFPattern const1_; - DFPattern const2_; + DFPattern a_; + DFPattern b_; + DFPattern c_; }; /*! \brief Simplifying x/sqrt to x*sqrt */ @@ -800,7 +875,8 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); - composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index e84d238aaa75..a930f45a07ac 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -560,46 +560,85 @@ def test_concretize_multiple(): assert tvm.ir.structural_equal(actual, expected) -def test_simplify_consecutive_add(): - shape = (32, 1, 1) - c_data = np.empty(shape).astype("float32") - c1 = relay.const(c_data) - c2 = relay.const(c_data) - - def before_const_right(): - x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") - w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") - y = relay.nn.conv2d(x, w, padding=(1, 1)) - y = relay.add(y, c1) - y = relay.add(y, c2) - y = relay.nn.relu(y) - return relay.Function([x, w], y) - - def before_const_left(): - x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") - w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") - y = relay.nn.conv2d(x, w, padding=(1, 1)) - y = relay.add(c1, y) - y = relay.add(c2, y) - y = relay.nn.relu(y) - return relay.Function([x, w], y) - - def expected(): - x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") - w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") - y = relay.nn.conv2d(x, w, padding=(1, 1)) - c3 = relay.add(c1, c2) - y = relay.add(y, c3) - y = relay.nn.relu(y) - return relay.Function([x, w], y) - - zr = before_const_right() - zl = before_const_left() - zzr = run_opt_pass(zr, transform.SimplifyExpr()) - zzl = run_opt_pass(zl, transform.SimplifyExpr()) - after = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(zzr, after) - assert tvm.ir.structural_equal(zzl, after) +def test_simplify_mul_add(): + def check_simple_fold(origin_exprs, expect_expr, expect_fold): + for origin_expr in origin_exprs: + simple_expr = run_opt_pass(origin_expr, transform.SimplifyExpr()) + assert tvm.ir.structural_equal(simple_expr, expect_expr) + + fold_expr = run_opt_pass(simple_expr, transform.FoldConstant()) + assert tvm.ir.structural_equal(fold_expr, expect_fold) + + n = 32 + c1_val = np.random.uniform(size=n).astype("float32") + c2_val = np.random.uniform(size=n).astype("float32") + c3_val = np.random.uniform(size=n).astype("float32") + + x = relay.var("x", shape=(n,), dtype="float32") + c1 = relay.const(c1_val) + c2 = relay.const(c2_val) + c3 = relay.const(c3_val) + + # add-add -> add + origin_exprs = [ + x + c1 + c2, + c1 + x + c2, + c1 + c2 + x, + ] + expect_expr = x + (c1 + c2) + expect_fold = x + relay.const(c1_val + c2_val) + check_simple_fold(origin_exprs, expect_expr, expect_fold) + + # mul-mul -> mul + origin_exprs = [ + x * c1 * c2, + c1 * x * c2, + c1 * c2 * x, + ] + expect_expr = x * (c1 * c2) + expect_fold = x * relay.const(c1_val * c2_val) + check_simple_fold(origin_exprs, expect_expr, expect_fold) + + # add-mul -> mul-add + origin_exprs = [ + (x + c1) * c2, + (c1 + x) * c2, + c2 * (x + c1), + c2 * (c1 + x), + ] + expect_expr = x * c2 + c1 * c2 + expect_fold = x * c2 + relay.const(c1_val * c2_val) + check_simple_fold(origin_exprs, expect_expr, expect_fold) + + # add-mul-add -> mul-add + origin_exprs = [ + (x + c1) * c2 + c3, + (c1 + x) * c2 + c3, + c2 * (x + c1) + c3, + c2 * (c1 + x) + c3, + c3 + (x + c1) * c2, + c3 + (c1 + x) * c2, + c3 + c2 * (x + c1), + c3 + c2 * (c1 + x), + ] + expect_expr = x * c2 + (c1 * c2 + c3) + expect_fold = x * c2 + relay.const(c1_val * c2_val + c3_val) + check_simple_fold(origin_exprs, expect_expr, expect_fold) + + # mul-add-mul -> mul-add + origin_exprs = [ + (x * c1 + c2) * c3, + (c1 * x + c2) * c3, + (c2 + x * c1) * c3, + (c2 + c1 * x) * c3, + c3 * (x * c1 + c2), + c3 * (c1 * x + c2), + c3 * (c2 + x * c1), + c3 * (c2 + c1 * x), + ] + expect_expr = x * (c1 * c3) + c2 * c3 + expect_fold = x * relay.const(c1_val * c3_val) + relay.const(c2_val * c3_val) + check_simple_fold(origin_exprs, expect_expr, expect_fold) def test_simplify_rsqrt(): From 14d8bbae8264edfa44142b646ed6afc561dc9cf1 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Fri, 28 Oct 2022 15:29:17 +0800 Subject: [PATCH 2/2] apply FoldConstant inside SimplifyExpr --- src/relay/transforms/simplify_expr.cc | 113 ++++++------------ tests/python/relay/test_pass_simplify_expr.py | 32 ++--- 2 files changed, 46 insertions(+), 99 deletions(-) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index b4a2c2c817a1..0dfb45577280 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -672,69 +672,41 @@ class EliminateIdentityRewrite : public DFPatternRewrite { DFPattern const_; }; -/*! - * \brief Returns whether \p expr is a ConstantNode or is a Call with - * \p IsConstantExpr args. - */ -bool IsConstantExpr(const Expr& expr) { - if (expr.as()) { - return true; - } else if (const CallNode* call = expr.as()) { - return std::all_of(call->args.begin(), call->args.end(), IsConstantExpr); - } else { - return false; - } -} - /*! \brief Switch adjacent add-mul with constants to mul-add. * As mul-add pattern is more friendly to FoldScaleAxis. */ class SwitchAddMultiply : public DFPatternRewrite { public: SwitchAddMultiply() { - a_ = IsWildcard(); - b_ = IsWildcard(); - c_ = IsWildcard(); - pattern_ = (a_ + b_) * c_; + x_ = IsWildcard(); + c1_ = IsConstant(); + c2_ = IsConstant(); + pattern_ = (x_ + c1_) * c2_; } Expr Callback(const Expr& pre, const Expr& post, const Map>& node_map) const override { - auto a = node_map[a_][0]; - auto b = node_map[b_][0]; - auto c = node_map[c_][0]; - - bool is_a_const = IsConstantExpr(a); - bool is_b_const = IsConstantExpr(b); - bool is_c_const = IsConstantExpr(c); + auto x = node_map[x_][0]; + auto c1 = node_map[c1_][0]; + auto c2 = node_map[c2_][0]; - if (!is_c_const) { - return post; - } - if (is_a_const && is_b_const) { + if (x.as()) { return post; } - Expr x, c_add, c_mul; - c_mul = c; - if (is_a_const) { - x = b; - c_add = a; - } else if (is_b_const) { - x = a; - c_add = b; - } else { - return post; - } + Expr const_expr = Call(Op::Get("multiply"), {c1, c2}); + IRModule const_mod = IRModule::FromExpr(const_expr); + const_mod = transform::FoldConstant()(const_mod); + GlobalVar const_main = const_mod->GetGlobalVar("main"); + Expr const_val = Downcast(const_mod->functions[const_main])->body; - auto bias = Call(Op::Get("multiply"), {c_add, c_mul}); - return Call(Op::Get("add"), {Call(Op::Get("multiply"), {x, c_mul}), bias}); + return Call(Op::Get("add"), {Call(Op::Get("multiply"), {x, c2}), const_val}); } private: - DFPattern a_; - DFPattern b_; - DFPattern c_; + DFPattern x_; + DFPattern c1_; + DFPattern c2_; }; /*! \brief Simplify two adjacent multiply or add with constants for further constant folding. @@ -743,51 +715,36 @@ class SwitchAddMultiply : public DFPatternRewrite { class SimplifyAdjacentMultiplyOrAdd : public DFPatternRewrite { public: SimplifyAdjacentMultiplyOrAdd() { - a_ = IsWildcard(); - b_ = IsWildcard(); - c_ = IsWildcard(); - pattern_ = (a_ * b_ * c_) || (a_ + b_ + c_); + x_ = IsWildcard(); + c1_ = IsConstant(); + c2_ = IsConstant(); + pattern_ = (x_ * c1_ * c2_) || (x_ + c1_ + c2_); } Expr Callback(const Expr& pre, const Expr& post, const Map>& node_map) const override { const CallNode* call = pre.as(); - auto a = node_map[a_][0]; - auto b = node_map[b_][0]; - auto c = node_map[c_][0]; - - bool is_a_const = IsConstantExpr(a); - bool is_b_const = IsConstantExpr(b); - bool is_c_const = IsConstantExpr(c); + auto x = node_map[x_][0]; + auto c1 = node_map[c1_][0]; + auto c2 = node_map[c2_][0]; - if (is_a_const && is_b_const && is_c_const) { + if (x.as()) { return post; } - Expr x, c1, c2; - if (is_a_const && is_b_const) { - x = c; - c1 = a; - c2 = b; - } else if (is_a_const && is_c_const) { - x = b; - c1 = a; - c2 = c; - } else if (is_b_const && is_c_const) { - x = a; - c1 = b; - c2 = c; - } else { - return post; - } - auto const_res = Call(call->op, {c1, c2}); - return Call(call->op, {x, const_res}); + Expr const_expr = Call(call->op, {c1, c2}); + IRModule const_mod = IRModule::FromExpr(const_expr); + const_mod = transform::FoldConstant()(const_mod); + GlobalVar const_main = const_mod->GetGlobalVar("main"); + Expr const_val = Downcast(const_mod->functions[const_main])->body; + + return Call(call->op, {x, const_val}); } private: - DFPattern a_; - DFPattern b_; - DFPattern c_; + DFPattern x_; + DFPattern c1_; + DFPattern c2_; }; /*! \brief Simplifying x/sqrt to x*sqrt */ diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index a930f45a07ac..8d5ea28ade61 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -561,14 +561,11 @@ def test_concretize_multiple(): def test_simplify_mul_add(): - def check_simple_fold(origin_exprs, expect_expr, expect_fold): + def check_simple_fold(origin_exprs, expect_expr): for origin_expr in origin_exprs: simple_expr = run_opt_pass(origin_expr, transform.SimplifyExpr()) assert tvm.ir.structural_equal(simple_expr, expect_expr) - fold_expr = run_opt_pass(simple_expr, transform.FoldConstant()) - assert tvm.ir.structural_equal(fold_expr, expect_fold) - n = 32 c1_val = np.random.uniform(size=n).astype("float32") c2_val = np.random.uniform(size=n).astype("float32") @@ -583,21 +580,17 @@ def check_simple_fold(origin_exprs, expect_expr, expect_fold): origin_exprs = [ x + c1 + c2, c1 + x + c2, - c1 + c2 + x, ] - expect_expr = x + (c1 + c2) - expect_fold = x + relay.const(c1_val + c2_val) - check_simple_fold(origin_exprs, expect_expr, expect_fold) + expect_expr = x + relay.const(c1_val + c2_val) + check_simple_fold(origin_exprs, expect_expr) # mul-mul -> mul origin_exprs = [ x * c1 * c2, c1 * x * c2, - c1 * c2 * x, ] - expect_expr = x * (c1 * c2) - expect_fold = x * relay.const(c1_val * c2_val) - check_simple_fold(origin_exprs, expect_expr, expect_fold) + expect_expr = x * relay.const(c1_val * c2_val) + check_simple_fold(origin_exprs, expect_expr) # add-mul -> mul-add origin_exprs = [ @@ -606,9 +599,8 @@ def check_simple_fold(origin_exprs, expect_expr, expect_fold): c2 * (x + c1), c2 * (c1 + x), ] - expect_expr = x * c2 + c1 * c2 - expect_fold = x * c2 + relay.const(c1_val * c2_val) - check_simple_fold(origin_exprs, expect_expr, expect_fold) + expect_expr = x * c2 + relay.const(c1_val * c2_val) + check_simple_fold(origin_exprs, expect_expr) # add-mul-add -> mul-add origin_exprs = [ @@ -621,9 +613,8 @@ def check_simple_fold(origin_exprs, expect_expr, expect_fold): c3 + c2 * (x + c1), c3 + c2 * (c1 + x), ] - expect_expr = x * c2 + (c1 * c2 + c3) - expect_fold = x * c2 + relay.const(c1_val * c2_val + c3_val) - check_simple_fold(origin_exprs, expect_expr, expect_fold) + expect_expr = x * c2 + relay.const(c1_val * c2_val + c3_val) + check_simple_fold(origin_exprs, expect_expr) # mul-add-mul -> mul-add origin_exprs = [ @@ -636,9 +627,8 @@ def check_simple_fold(origin_exprs, expect_expr, expect_fold): c3 * (c2 + x * c1), c3 * (c2 + c1 * x), ] - expect_expr = x * (c1 * c3) + c2 * c3 - expect_fold = x * relay.const(c1_val * c3_val) + relay.const(c2_val * c3_val) - check_simple_fold(origin_exprs, expect_expr, expect_fold) + expect_expr = x * relay.const(c1_val * c3_val) + relay.const(c2_val * c3_val) + check_simple_fold(origin_exprs, expect_expr) def test_simplify_rsqrt():