From 56dd4d7e26617a8db749571ccf09ef8ad7f3bae9 Mon Sep 17 00:00:00 2001 From: hgt312 Date: Mon, 12 Apr 2021 16:45:19 +0800 Subject: [PATCH 1/4] init --- src/relay/transforms/simplify_expr.cc | 67 +++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 762aa58f7298..8c2cbf59b533 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -75,6 +75,59 @@ class SimplifyReshape : public DFPatternRewrite { DFPattern x_; }; +/*! + * \brief SimplifyCastLike matches the pattern of cast data to the same dtype. + */ +class SimplifyCastLike : public DFPatternRewrite { + public: + explicit SimplifyCastLike() { + data_pat_ = IsWildcard(); + like_pat_ = IsWildcard(); + pattern_ = IsOp("cast_like")({data_pat_, like_pat_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + auto data = node_map[data_pat_][0]; + const TensorTypeNode* data_ty = data->checked_type().as(); + const TensorTypeNode* like_ty = pre->checked_type().as(); + if (like_ty->dtype == data_ty->dtype) { + return data; + } + return post; + } + + protected: + DFPattern data_pat_; + DFPattern like_pat_; +}; + +/*! + * \brief SimplifyCast matches the pattern of cast data to the same dtype. + */ +class SimplifyCast : public DFPatternRewrite { + public: + explicit SimplifyCast() { + data_pat_ = IsWildcard(); + pattern_ = IsOp("cast")({data_pat_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const override { + const CallNode* call = pre.as(); + auto attrs = call->attrs.as(); + auto data = node_map[data_pat_][0]; + const TensorTypeNode* data_ty = data->checked_type().as(); + if (attrs->dtype == data_ty->dtype) { + return data; + } + return post; + } + + protected: + DFPattern data_pat_; +}; + /*! * \brief SimplifyTranspose matches the pattern of consecutive transpose op, * and merges or cancels them. @@ -321,6 +374,17 @@ class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite { } }; +class ConcretizeFullLikeRewrite : public ConcretizeLikeRewrite { + public: + ConcretizeFullLikeRewrite() : ConcretizeLikeRewrite(Op::Get("full_like")) {} + + Expr Concretize(const Map>& node_map, Array shape, + DataType dtype) const override { + // `like_pat_` here is `fill_value` + return MakeFull(node_map[like_pat_][0], shape, dtype); + } +}; + class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite { public: ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {} @@ -439,12 +503,15 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { DFPatternRewriteComposer composer; composer.AddRewrite(); composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); + composer.AddRewrite(); + composer.AddRewrite(); composer.AddRewrite(); return RewritePatterns(composer.MakeCallbacks(), expr, mod); } From b154cfed7fdaf424d198798b193721a164966c03 Mon Sep 17 00:00:00 2001 From: hgt312 Date: Mon, 12 Apr 2021 16:45:28 +0800 Subject: [PATCH 2/4] test --- tests/python/relay/test_pass_simplify_expr.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index d015cdd36c2d..7afc698c954b 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -236,6 +236,27 @@ def check(x, y=None, do_nothing=False): check(id_op(const, x), id_op(op_like(x), x)) +def test_simplify_cast_like(): + dtype = "int32" + data = relay.var("data", shape=(3, 4, 5), dtype=dtype) + dtype_like = relay.var("dtype_like", shape=(2, 2, 2), dtype=dtype) + expr = relay.cast_like(data, dtype_like) + + expected = run_infer_type(data) + actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) + + +def test_simplify_cast(): + dtype = "int32" + data = relay.var("data", shape=(3, 4, 5), dtype=dtype) + expr = relay.cast(data, dtype) + + expected = run_infer_type(data) + actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) + + def test_concretize_reshape_like(): data = relay.var("data", shape=(2, 3, 4), dtype="float32") shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32") @@ -276,6 +297,17 @@ def test_concretize_ones_like(): assert tvm.ir.structural_equal(actual, expected) +def test_concretize_full_like(): + dtype = "int32" + shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) + fill_value = relay.var("fill", relay.TensorType((), "float32")) + expr = relay.full_like(shape_like, fill_value) + + expected = run_infer_type(relay.full(fill_value, (3, 4, 5), dtype)) + actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual, expected) + + def test_concretize_collapse_sum_like(): data = relay.var("data", shape=(3, 3, 3), dtype="float32") shape_like = relay.var("shape_like", shape=(3,), dtype="float32") From 9132cb638a8271ab14e9f679030b2afffccc411c Mon Sep 17 00:00:00 2001 From: hgt312 Date: Mon, 12 Apr 2021 17:02:42 +0800 Subject: [PATCH 3/4] lint --- src/relay/transforms/simplify_expr.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 8c2cbf59b533..1086ddfdee59 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -80,7 +80,7 @@ class SimplifyReshape : public DFPatternRewrite { */ class SimplifyCastLike : public DFPatternRewrite { public: - explicit SimplifyCastLike() { + SimplifyCastLike() { data_pat_ = IsWildcard(); like_pat_ = IsWildcard(); pattern_ = IsOp("cast_like")({data_pat_, like_pat_}); @@ -107,7 +107,7 @@ class SimplifyCastLike : public DFPatternRewrite { */ class SimplifyCast : public DFPatternRewrite { public: - explicit SimplifyCast() { + SimplifyCast() { data_pat_ = IsWildcard(); pattern_ = IsOp("cast")({data_pat_}); } From 0cb658cd91d24ee9c55c0390266e80c496dd631f Mon Sep 17 00:00:00 2001 From: hgt312 Date: Tue, 13 Apr 2021 12:22:18 +0800 Subject: [PATCH 4/4] try fix --- src/relay/transforms/simplify_expr.cc | 41 ++++--------------- tests/python/relay/test_pass_simplify_expr.py | 21 ++++------ 2 files changed, 14 insertions(+), 48 deletions(-) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 1086ddfdee59..5662ef5b45a6 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -75,33 +75,6 @@ class SimplifyReshape : public DFPatternRewrite { DFPattern x_; }; -/*! - * \brief SimplifyCastLike matches the pattern of cast data to the same dtype. - */ -class SimplifyCastLike : public DFPatternRewrite { - public: - SimplifyCastLike() { - data_pat_ = IsWildcard(); - like_pat_ = IsWildcard(); - pattern_ = IsOp("cast_like")({data_pat_, like_pat_}); - } - - Expr Callback(const Expr& pre, const Expr& post, - const Map>& node_map) const override { - auto data = node_map[data_pat_][0]; - const TensorTypeNode* data_ty = data->checked_type().as(); - const TensorTypeNode* like_ty = pre->checked_type().as(); - if (like_ty->dtype == data_ty->dtype) { - return data; - } - return post; - } - - protected: - DFPattern data_pat_; - DFPattern like_pat_; -}; - /*! * \brief SimplifyCast matches the pattern of cast data to the same dtype. */ @@ -109,23 +82,24 @@ class SimplifyCast : public DFPatternRewrite { public: SimplifyCast() { data_pat_ = IsWildcard(); - pattern_ = IsOp("cast")({data_pat_}); + like_pat_ = IsWildcard(); + pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_}); } Expr Callback(const Expr& pre, const Expr& post, const Map>& node_map) const override { const CallNode* call = pre.as(); - auto attrs = call->attrs.as(); - auto data = node_map[data_pat_][0]; - const TensorTypeNode* data_ty = data->checked_type().as(); - if (attrs->dtype == data_ty->dtype) { - return data; + const TensorTypeNode* data_ty = call->args[0]->checked_type().as(); + const TensorTypeNode* like_ty = pre->checked_type().as(); + if (like_ty->dtype == data_ty->dtype) { + return node_map[data_pat_][0]; } return post; } protected: DFPattern data_pat_; + DFPattern like_pat_; }; /*! @@ -510,7 +484,6 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); - composer.AddRewrite(); composer.AddRewrite(); composer.AddRewrite(); return RewritePatterns(composer.MakeCallbacks(), expr, mod); diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 7afc698c954b..d1dffa34578b 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -236,25 +236,18 @@ def check(x, y=None, do_nothing=False): check(id_op(const, x), id_op(op_like(x), x)) -def test_simplify_cast_like(): - dtype = "int32" - data = relay.var("data", shape=(3, 4, 5), dtype=dtype) - dtype_like = relay.var("dtype_like", shape=(2, 2, 2), dtype=dtype) - expr = relay.cast_like(data, dtype_like) - - expected = run_infer_type(data) - actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) - - def test_simplify_cast(): dtype = "int32" data = relay.var("data", shape=(3, 4, 5), dtype=dtype) - expr = relay.cast(data, dtype) + expr1 = relay.cast(data, dtype) + dtype_like = relay.var("dtype_like", shape=(2, 2, 2), dtype=dtype) + expr2 = relay.cast_like(data, dtype_like) expected = run_infer_type(data) - actual = run_opt_pass(expr, relay.transform.SimplifyExpr()) - assert tvm.ir.structural_equal(actual, expected) + actual1 = run_opt_pass(expr1, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual1, expected) + actual2 = run_opt_pass(expr2, relay.transform.SimplifyExpr()) + assert tvm.ir.structural_equal(actual2, expected) def test_concretize_reshape_like():