From 0ba25441f59825edd0049fe8efd488f220798946 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 13 Jan 2021 16:26:35 +0900 Subject: [PATCH 1/5] Add if pattern commit 1ee052fd494a5bdd881c242c3ea0c95cf2a613e5 Author: Masahiro Masuda Date: Sat Dec 26 22:19:17 2020 +0900 add comment commit c846a6999e9c9e48fbc019780e705a990f46cb22 Author: Masahiro Masuda Date: Sat Dec 26 21:14:20 2020 +0900 max_out_size rewrite added to the test commit 2c7c7fbd0e6563aba694e7fb6baa7bda8e4fadca Author: Masahiro Masuda Date: Sat Dec 26 20:57:55 2020 +0900 max_out_size rewrite working commit 319e930acb8162c1ec4a5d4fb71d134580a68f13 Author: Masahiro Masuda Date: Sat Dec 26 20:43:16 2020 +0900 refactor dyn strided slice pattern commit fb6917b703440748800bde624bc20efaf5798b8a Author: Masahiro Masuda Date: Sat Dec 26 11:21:33 2020 +0900 update NMS pattern following frontend change commit 255a98f1da8f300d4fe417cce3587c0d71e38ed3 Author: Masahiro Masuda Date: Thu Dec 24 05:19:31 2020 +0900 add some comment to explain the pattern commit 52cea1cc2bff533ca60acfc2416477fc8b058428 Author: Masahiro Masuda Date: Wed Dec 23 08:35:14 2020 +0900 revert tutorial change commit d3e0e0d7e2427c40067d6ad2680ec5b3f0076223 Author: Masahiro Masuda Date: Wed Dec 23 08:02:29 2020 +0900 test fixed by setting force_surpress=False commit 2fa1a574f932001be2d8f601338a342dab92f79c Author: Masahiro Masuda Date: Wed Dec 23 07:22:32 2020 +0900 fixed coord_start commit 6ba88f27dec1bdb0b0ba746c268591a59264088e Author: Masahiro Masuda Date: Wed Dec 23 06:50:46 2020 +0900 add doc commit 8d386b6a1c92ce4fe3349ff20e320199a1b5b310 Author: Masahiro Masuda Date: Wed Dec 23 05:27:26 2020 +0900 updated tutorial commit 3206b49ecfdd874e0ff8feb0fa586c4c4282f705 Author: Masahiro Masuda Date: Wed Dec 23 05:04:44 2020 +0900 update object detection test to add rewrite commit 74bebb2f4376aeb67d8c4aad395f9f2661fe6b3e Author: Masahiro Masuda Date: Wed Dec 23 05:02:15 2020 +0900 add a pattern to rewrite nms to batched nms commit f410e6dde0ed949b90312c5a7ddbb6c234f9acc1 Author: Masahiro Masuda Date: Sat Dec 26 22:20:16 2020 +0900 add comment commit f1e078b0724bd22e7be0a812055e1c7c650d94da Author: Masahiro Masuda Date: Sat Dec 26 19:54:22 2020 +0900 Add if pattern --- include/tvm/relay/dataflow_pattern.h | 20 +++++++++ include/tvm/relay/dataflow_pattern_functor.h | 3 ++ python/tvm/relay/dataflow_pattern/__init__.py | 43 +++++++++++++++++++ src/relay/ir/dataflow_matcher.cc | 12 ++++++ src/relay/ir/dataflow_pattern.cc | 22 ++++++++++ src/relay/ir/dataflow_pattern_functor.cc | 6 +++ src/relay/ir/indexed_graph.cc | 6 +++ 7 files changed, 112 insertions(+) diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 5b2734f52ede..1b0c0aca7ff6 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -260,6 +260,26 @@ class TupleGetItemPatternNode : public DFPatternNode { TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); }; +class IfPatternNode : public DFPatternNode { + public: + DFPattern cond, true_branch, false_branch; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cond", &cond); + v->Visit("true_branch", &true_branch); + v->Visit("false_branch", &false_branch); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.IfPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfPatternNode, DFPatternNode); +}; + +class IfPattern : public DFPattern { + public: + TVM_DLL IfPattern(DFPattern cond, DFPattern then_clause, DFPattern else_clause); + TVM_DEFINE_OBJECT_REF_METHODS(IfPattern, DFPattern, IfPatternNode); +}; + class TupleGetItemPattern : public DFPattern { public: TVM_DLL TupleGetItemPattern(DFPattern tuple, int index); diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h index f04977b86ccb..bff9e23ef046 100644 --- a/include/tvm/relay/dataflow_pattern_functor.h +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -91,6 +91,7 @@ class DFPatternFunctor { virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; @@ -116,6 +117,7 @@ class DFPatternFunctor { RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); @@ -144,6 +146,7 @@ class DFPatternVisitor : public DFPatternFunctor { void VisitDFPattern_(const ShapePatternNode* op) override; void VisitDFPattern_(const TupleGetItemPatternNode* op) override; void VisitDFPattern_(const TuplePatternNode* op) override; + void VisitDFPattern_(const IfPatternNode* op) override; void VisitDFPattern_(const TypePatternNode* op) override; void VisitDFPattern_(const VarPatternNode* op) override; void VisitDFPattern_(const WildcardPatternNode* op) override; diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index f5161ad0bfa7..6f764e1651da 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -314,6 +314,29 @@ def is_tuple_get_item(tuple_value: "DFPattern", index: Optional[int] = None) -> return TupleGetItemPattern(tuple_value, index) +def is_if(cond, true_branch, false_branch): + """ + Syntatic sugar for creating an IfPattern. + + Parameters + ---------- + cond: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the condition of If. + + true_branch: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the true branch of If. + + false_branch: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the false branch of If. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return IfPattern(cond, true_branch, false_branch) + + def wildcard() -> "DFPattern": """ Syntatic sugar for creating a WildcardPattern. @@ -536,6 +559,26 @@ def __init__( self.__init_handle_by_constructor__(ffi.FunctionPattern, params, body) +@register_df_node +class IfPattern(DFPattern): + """A patern matching a Relay If. + + Parameters + ---------- + cond: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the condition of If. + + true_branch: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the true branch of If. + + false_branch: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the false branch of If. + """ + + def __init__(self, cond: "DFPattern", true_branch: "DFPattern", false_branch: "DFPattern"): + self.__init_handle_by_constructor__(ffi.IfPattern, cond, true_branch, false_branch) + + @register_df_node class TuplePattern(DFPattern): """A patern matching a Relay Tuple. diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index e4c0c7fa1c94..4b30ef1b6a37 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -58,6 +58,7 @@ class DFPatternMatcher : public DFPatternFunctor()) { + auto cond = if_node->cond; + auto true_branch = if_node->true_branch; + auto false_branch = if_node->false_branch; + return VisitDFPattern(op->cond, cond) && VisitDFPattern(op->true_branch, true_branch) && + VisitDFPattern(op->false_branch, false_branch); + } + return false; +} + Expr InferType(const Expr& expr) { auto mod = IRModule::FromExpr(expr); mod = transform::InferType()(mod); diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index 086c3852b13f..1e268fb00d97 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -112,6 +112,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")"; }); +IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern false_branch) { + ObjectPtr n = make_object(); + n->cond = std::move(cond); + n->true_branch = std::move(true_branch); + n->false_branch = std::move(false_branch); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(IfPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.IfPattern") + .set_body_typed([](DFPattern cond, DFPattern true_branch, DFPattern false_branch) { + return IfPattern(cond, true_branch, false_branch); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IfPattern(" << node->cond << ", " << node->true_branch << ", " + << node->false_branch << ")"; + }); + TuplePattern::TuplePattern(tvm::Array fields) { ObjectPtr n = make_object(); n->fields = std::move(fields); diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc index aaa4f84b3254..25b247306229 100644 --- a/src/relay/ir/dataflow_pattern_functor.cc +++ b/src/relay/ir/dataflow_pattern_functor.cc @@ -81,6 +81,12 @@ void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) { } } +void DFPatternVisitor::VisitDFPattern_(const IfPatternNode* op) { + VisitDFPattern(op->cond); + VisitDFPattern(op->true_branch); + VisitDFPattern(op->false_branch); +} + void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); } void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 4ba053c429de..3399ce30e678 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -282,6 +282,12 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { } } + void VisitDFPattern_(const IfPatternNode* op, NodePtr parent) override{ + VisitDFPattern(op->cond, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->true_branch, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->false_branch, graph_.node_map_[GetRef(op)]); + } + void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); } From b3baacccba53b20dae0d36f1c51949095c6a9672 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 14 Jan 2021 20:16:44 +0900 Subject: [PATCH 2/5] add doc --- docs/langref/relay_pattern.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index ff02e50eb5fb..0bc112f33869 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -230,6 +230,21 @@ The next example is matching function nodes with a specific attribute: f = relay.Function([x, y], x + y).with_attr("Composite", "add") assert pattern.match(f) +Relay ``If`` expression can be matched if all of its condition, true branch and false branch +are matched: + +.. code-block:: python + + def test_match_if(): + x = is_var("x") + y = is_var("y") + pat = is_if(is_op("less")(x, y), x, y) + + x = relay.var("x") + y = relay.var("y") + cond = x < y + + assert pat.match(relay.expr.If(cond, x, y)) Matching Diamonds and Post-Dominator Graphs ******************************************* @@ -294,6 +309,7 @@ The high level design is to introduce a language of patterns for now we propose | is_op(op_name) | is_tuple() | is_tuple_get_item(pattern, index = None) + | is_if(cond, tru, fls) | pattern1 `|` pattern2 | dominates(parent_pattern, path_pattern, child_pattern) | FunctionPattern(params, body) From 85d3816c1c04c5175ea7e644a0181cd3ebfc4de8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 14 Jan 2021 20:21:10 +0900 Subject: [PATCH 3/5] add test --- tests/python/relay/test_dataflow_pattern.py | 38 +++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index f30a4e747c33..50015a480004 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -127,6 +127,17 @@ def test_AttrPattern(): assert op.attrs["TOpPattern"] == K_ELEMWISE +def test_IfPattern(): + x = is_var("x") + y = is_var("y") + pat = is_if(is_op("less")(x, y), x, y) + + assert isinstance(pat, IfPattern) + assert isinstance(pat.cond, CallPattern) + assert isinstance(pat.true_branch, VarPattern) + assert isinstance(pat.false_branch, VarPattern) + + ## MATCHER TESTS @@ -198,6 +209,30 @@ def test_no_match_func(): assert not func_pattern.match(relay.Function([x, y], x - y)) +def test_match_if(): + x = is_var("x") + y = is_var("y") + pat = is_if(is_op("less")(x, y), x, y) + + x = relay.var("x") + y = relay.var("y") + cond = x < y + + assert pat.match(relay.expr.If(cond, x, y)) + + +def test_no_match_if(): + x = is_var("x") + y = is_var("y") + pat = is_if(is_op("less")(x, y), x, y) + + x = relay.var("x") + y = relay.var("y") + + assert not pat.match(relay.expr.If(x > y, x, y)) + assert not pat.match(relay.expr.If(x < y, y, x)) + + def test_match_option(): x = relay.var("x") w = relay.var("w") @@ -1512,3 +1547,6 @@ def test_partition_constant_embedding(): test_partition_option() test_match_match() test_partition_constant_embedding() + test_IfPattern() + test_match_if() + test_no_match_if() From 0a5b976db15a5347e6cbdf7c187fcacdfef4cdd6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 14 Jan 2021 20:24:17 +0900 Subject: [PATCH 4/5] doc formatting --- docs/langref/relay_pattern.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index 0bc112f33869..992954c9a5b1 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -230,7 +230,7 @@ The next example is matching function nodes with a specific attribute: f = relay.Function([x, y], x + y).with_attr("Composite", "add") assert pattern.match(f) -Relay ``If`` expression can be matched if all of its condition, true branch and false branch +A Relay ``If`` expression can be matched if all of its condition, true branch and false branch are matched: .. code-block:: python @@ -309,7 +309,7 @@ The high level design is to introduce a language of patterns for now we propose | is_op(op_name) | is_tuple() | is_tuple_get_item(pattern, index = None) - | is_if(cond, tru, fls) + | is_if(cond, tru, fls) | pattern1 `|` pattern2 | dominates(parent_pattern, path_pattern, child_pattern) | FunctionPattern(params, body) From c8d26e5fc7c55b956bf8da50b6feaf25f5cda10a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 14 Jan 2021 20:28:30 +0900 Subject: [PATCH 5/5] cpplint fix --- src/relay/ir/indexed_graph.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 3399ce30e678..9ee5c9cf6b85 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -282,7 +282,7 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { } } - void VisitDFPattern_(const IfPatternNode* op, NodePtr parent) override{ + void VisitDFPattern_(const IfPatternNode* op, NodePtr parent) override { VisitDFPattern(op->cond, graph_.node_map_[GetRef(op)]); VisitDFPattern(op->true_branch, graph_.node_map_[GetRef(op)]); VisitDFPattern(op->false_branch, graph_.node_map_[GetRef(op)]);