Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,24 @@ are matched:

assert pat.match(relay.expr.If(cond, x, y))


A Relay ``Let`` expression can be matched if all of its variable, value, and body
are matched:

.. code-block:: python

def test_match_let():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

x = relay.var("x")
y = relay.var("y")
lv = relay.var("let")
cond = x < y
assert pat.match(relay.expr.Let(lv, cond, lv))

Matching Diamonds and Post-Dominator Graphs
*******************************************

Expand Down Expand Up @@ -310,6 +328,7 @@ The high level design is to introduce a language of patterns for now we propose
| is_tuple()
| is_tuple_get_item(pattern, index = None)
| is_if(cond, tru, fls)
| is_let(var, value, body)
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)
| FunctionPattern(params, body)
Expand Down Expand Up @@ -367,6 +386,16 @@ Function Pattern

Match a Function with a body and parameters

If Pattern
**********

Match an If with condition, true branch, and false branch

Let Pattern
***********

Match a Let with a variable, value, and body

Applications
============

Expand Down
36 changes: 36 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,42 @@ class FunctionPattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode);
};

/*! \brief A binding of a sub-network. */
class LetPatternNode : public DFPatternNode {
public:
/*! \brief The variable we bind to */
DFPattern var;
/*! \brief The value we bind var to */
DFPattern value;
/*! \brief The body of the let binding */
DFPattern body;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.LetPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(LetPatternNode, DFPatternNode);
};

/*!
* \brief Let binding that binds a local var
*/
class LetPattern : public DFPattern {
public:
/*!
* \brief The constructor
* \param var The variable that is bound to.
* \param value The value used to bind to the variable.
* \param body The body of the let binding.
*/
TVM_DLL LetPattern(DFPattern var, DFPattern value, DFPattern body);

TVM_DEFINE_OBJECT_REF_METHODS(LetPattern, DFPattern, LetPatternNode);
};

/*! \brief Tuple of multiple Exprs */
class TuplePattern;
/*! \brief Tuple container */
Expand Down
11 changes: 7 additions & 4 deletions include/tvm/relay/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,19 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const LetPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPatternDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
Expand All @@ -115,9 +116,10 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(LetPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
Expand All @@ -143,10 +145,11 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const DominatorPatternNode* op) override;
void VisitDFPattern_(const ExprPatternNode* op) override;
void VisitDFPattern_(const FunctionPatternNode* op) override;
void VisitDFPattern_(const IfPatternNode* op) override;
void VisitDFPattern_(const LetPatternNode* op) override;
void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
void VisitDFPattern_(const IfPatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;
Expand Down
44 changes: 44 additions & 0 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,29 @@ def is_if(cond, true_branch, false_branch):
return IfPattern(cond, true_branch, false_branch)


def is_let(var, value, body):
"""
Syntatic sugar for creating a LetPattern.

Parameters
----------
var: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the variable of Let.

value: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the value of Let.

body: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the body where the binding is in effect.

Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting pattern.
"""
return LetPattern(var, value, body)


def wildcard() -> "DFPattern":
"""
Syntatic sugar for creating a WildcardPattern.
Expand Down Expand Up @@ -579,6 +602,27 @@ def __init__(self, cond: "DFPattern", true_branch: "DFPattern", false_branch: "D
self.__init_handle_by_constructor__(ffi.IfPattern, cond, true_branch, false_branch)


@register_df_node
class LetPattern(DFPattern):
"""A patern matching a Relay Let.

Parameters
----------
var: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the variable of Let.

value: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the value of Let.

body: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the body where the binding is in effect.

"""

def __init__(self, var: "DFPattern", value: "DFPattern", body: "DFPattern"):
self.__init_handle_by_constructor__(ffi.LetPattern, var, value, body)


@register_df_node
class TuplePattern(DFPattern):
"""A patern matching a Relay Tuple.
Expand Down
11 changes: 10 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
Expand Down Expand Up @@ -423,6 +424,14 @@ bool DFPatternMatcher::VisitDFPattern_(const IfPatternNode* op, const Expr& expr
return false;
}

bool DFPatternMatcher::VisitDFPattern_(const LetPatternNode* op, const Expr& expr) {
if (const auto* let_node = expr.as<LetNode>()) {
return VisitDFPattern(op->var, let_node->var) && VisitDFPattern(op->value, let_node->value) &&
VisitDFPattern(op->body, let_node->body);
}
return false;
}

Expr InferType(const Expr& expr) {
auto mod = IRModule::FromExpr(expr);
mod = transform::InferType()(mod);
Expand Down
22 changes: 22 additions & 0 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")";
});

LetPattern::LetPattern(DFPattern var, DFPattern value, DFPattern body) {
ObjectPtr<LetPatternNode> n = make_object<LetPatternNode>();
n->var = std::move(var);
n->value = std::move(value);
n->body = std::move(body);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(LetPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.LetPattern")
.set_body_typed([](DFPattern var, DFPattern value, DFPattern body) {
return LetPattern(var, value, body);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const LetPatternNode*>(ref.get());
p->stream << "LetPatternNode(" << node->var << ", " << node->value << ", " << node->body
<< ")";
});

IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern false_branch) {
ObjectPtr<IfPatternNode> n = make_object<IfPatternNode>();
n->cond = std::move(cond);
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ void DFPatternVisitor::VisitDFPattern_(const IfPatternNode* op) {
VisitDFPattern(op->false_branch);
}

void DFPatternVisitor::VisitDFPattern_(const LetPatternNode* op) {
VisitDFPattern(op->var);
VisitDFPattern(op->value);
VisitDFPattern(op->body);
}

void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); }

void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
VisitDFPattern(op->false_branch, graph_.node_map_[GetRef<DFPattern>(op)]);
}

void VisitDFPattern_(const LetPatternNode* op, NodePtr parent) override {
VisitDFPattern(op->var, graph_.node_map_[GetRef<DFPattern>(op)]);
VisitDFPattern(op->value, graph_.node_map_[GetRef<DFPattern>(op)]);
VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
}

void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override {
VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
}
Expand Down
39 changes: 39 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,18 @@ def test_IfPattern():
assert isinstance(pat.false_branch, VarPattern)


def test_LetPattern():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

assert isinstance(pat, LetPattern)
assert isinstance(pat.var, VarPattern)
assert isinstance(pat.value, CallPattern)
assert isinstance(pat.body, VarPattern)


## MATCHER TESTS


Expand Down Expand Up @@ -233,6 +245,33 @@ def test_no_match_if():
assert not pat.match(relay.expr.If(x < y, y, x))


def test_match_let():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

x = relay.var("x")
y = relay.var("y")
lv = relay.var("let")
cond = x < y
assert pat.match(relay.expr.Let(lv, cond, lv))


def test_no_match_let():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

x = relay.var("x")
y = relay.var("y")
lv = relay.var("let")

assert not pat.match(relay.expr.Let(lv, x > y, lv))
assert not pat.match(relay.expr.Let(lv, x < y, lv * x))


def test_match_option():
x = relay.var("x")
w = relay.var("w")
Expand Down