diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index ff02e50eb5fb..992954c9a5b1 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -230,6 +230,21 @@ The next example is matching function nodes with a specific attribute: f = relay.Function([x, y], x + y).with_attr("Composite", "add") assert pattern.match(f) +A Relay ``If`` expression can be matched if all of its condition, true branch and false branch +are matched: + +.. code-block:: python + + def test_match_if(): + x = is_var("x") + y = is_var("y") + pat = is_if(is_op("less")(x, y), x, y) + + x = relay.var("x") + y = relay.var("y") + cond = x < y + + assert pat.match(relay.expr.If(cond, x, y)) Matching Diamonds and Post-Dominator Graphs ******************************************* @@ -294,6 +309,7 @@ The high level design is to introduce a language of patterns for now we propose | is_op(op_name) | is_tuple() | is_tuple_get_item(pattern, index = None) + | is_if(cond, tru, fls) | pattern1 `|` pattern2 | dominates(parent_pattern, path_pattern, child_pattern) | FunctionPattern(params, body) diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 5b2734f52ede..1b0c0aca7ff6 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -260,6 +260,26 @@ class TupleGetItemPatternNode : public DFPatternNode { TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); }; +class IfPatternNode : public DFPatternNode { + public: + DFPattern cond, true_branch, false_branch; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cond", &cond); + v->Visit("true_branch", &true_branch); + v->Visit("false_branch", &false_branch); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.IfPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfPatternNode, DFPatternNode); +}; + +class IfPattern : public DFPattern { + public: + TVM_DLL IfPattern(DFPattern cond, DFPattern then_clause, DFPattern else_clause); + TVM_DEFINE_OBJECT_REF_METHODS(IfPattern, DFPattern, IfPatternNode); +}; + class TupleGetItemPattern : public DFPattern { public: TVM_DLL TupleGetItemPattern(DFPattern tuple, int index); diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h index f04977b86ccb..bff9e23ef046 100644 --- a/include/tvm/relay/dataflow_pattern_functor.h +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -91,6 +91,7 @@ class DFPatternFunctor { virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; @@ -116,6 +117,7 @@ class DFPatternFunctor { RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); @@ -144,6 +146,7 @@ class DFPatternVisitor : public DFPatternFunctor { void VisitDFPattern_(const ShapePatternNode* op) override; void VisitDFPattern_(const TupleGetItemPatternNode* op) override; void VisitDFPattern_(const TuplePatternNode* op) override; + void VisitDFPattern_(const IfPatternNode* op) override; void VisitDFPattern_(const TypePatternNode* op) override; void VisitDFPattern_(const VarPatternNode* op) override; void VisitDFPattern_(const WildcardPatternNode* op) override; diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index f5161ad0bfa7..6f764e1651da 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -314,6 +314,29 @@ def is_tuple_get_item(tuple_value: "DFPattern", index: Optional[int] = None) -> return TupleGetItemPattern(tuple_value, index) +def is_if(cond, true_branch, false_branch): + """ + Syntatic sugar for creating an IfPattern. + + Parameters + ---------- + cond: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the condition of If. + + true_branch: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the true branch of If. + + false_branch: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the false branch of If. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return IfPattern(cond, true_branch, false_branch) + + def wildcard() -> "DFPattern": """ Syntatic sugar for creating a WildcardPattern. @@ -536,6 +559,26 @@ def __init__( self.__init_handle_by_constructor__(ffi.FunctionPattern, params, body) +@register_df_node +class IfPattern(DFPattern): + """A patern matching a Relay If. + + Parameters + ---------- + cond: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the condition of If. + + true_branch: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the true branch of If. + + false_branch: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the false branch of If. + """ + + def __init__(self, cond: "DFPattern", true_branch: "DFPattern", false_branch: "DFPattern"): + self.__init_handle_by_constructor__(ffi.IfPattern, cond, true_branch, false_branch) + + @register_df_node class TuplePattern(DFPattern): """A patern matching a Relay Tuple. diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index e4c0c7fa1c94..4b30ef1b6a37 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -58,6 +58,7 @@ class DFPatternMatcher : public DFPatternFunctor()) { + auto cond = if_node->cond; + auto true_branch = if_node->true_branch; + auto false_branch = if_node->false_branch; + return VisitDFPattern(op->cond, cond) && VisitDFPattern(op->true_branch, true_branch) && + VisitDFPattern(op->false_branch, false_branch); + } + return false; +} + Expr InferType(const Expr& expr) { auto mod = IRModule::FromExpr(expr); mod = transform::InferType()(mod); diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index 086c3852b13f..1e268fb00d97 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -112,6 +112,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")"; }); +IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern false_branch) { + ObjectPtr n = make_object(); + n->cond = std::move(cond); + n->true_branch = std::move(true_branch); + n->false_branch = std::move(false_branch); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(IfPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.IfPattern") + .set_body_typed([](DFPattern cond, DFPattern true_branch, DFPattern false_branch) { + return IfPattern(cond, true_branch, false_branch); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IfPattern(" << node->cond << ", " << node->true_branch << ", " + << node->false_branch << ")"; + }); + TuplePattern::TuplePattern(tvm::Array fields) { ObjectPtr n = make_object(); n->fields = std::move(fields); diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc index aaa4f84b3254..25b247306229 100644 --- a/src/relay/ir/dataflow_pattern_functor.cc +++ b/src/relay/ir/dataflow_pattern_functor.cc @@ -81,6 +81,12 @@ void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) { } } +void DFPatternVisitor::VisitDFPattern_(const IfPatternNode* op) { + VisitDFPattern(op->cond); + VisitDFPattern(op->true_branch); + VisitDFPattern(op->false_branch); +} + void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); } void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 4ba053c429de..9ee5c9cf6b85 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -282,6 +282,12 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { } } + void VisitDFPattern_(const IfPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->cond, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->true_branch, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->false_branch, graph_.node_map_[GetRef(op)]); + } + void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); } diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index f30a4e747c33..50015a480004 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -127,6 +127,17 @@ def test_AttrPattern(): assert op.attrs["TOpPattern"] == K_ELEMWISE +def test_IfPattern(): + x = is_var("x") + y = is_var("y") + pat = is_if(is_op("less")(x, y), x, y) + + assert isinstance(pat, IfPattern) + assert isinstance(pat.cond, CallPattern) + assert isinstance(pat.true_branch, VarPattern) + assert isinstance(pat.false_branch, VarPattern) + + ## MATCHER TESTS @@ -198,6 +209,30 @@ def test_no_match_func(): assert not func_pattern.match(relay.Function([x, y], x - y)) +def test_match_if(): + x = is_var("x") + y = is_var("y") + pat = is_if(is_op("less")(x, y), x, y) + + x = relay.var("x") + y = relay.var("y") + cond = x < y + + assert pat.match(relay.expr.If(cond, x, y)) + + +def test_no_match_if(): + x = is_var("x") + y = is_var("y") + pat = is_if(is_op("less")(x, y), x, y) + + x = relay.var("x") + y = relay.var("y") + + assert not pat.match(relay.expr.If(x > y, x, y)) + assert not pat.match(relay.expr.If(x < y, y, x)) + + def test_match_option(): x = relay.var("x") w = relay.var("w") @@ -1512,3 +1547,6 @@ def test_partition_constant_embedding(): test_partition_option() test_match_match() test_partition_constant_embedding() + test_IfPattern() + test_match_if() + test_no_match_if()