Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion include/tvm/relay/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/relay/dataflow_pattern.h>
#include <tvm/relay/dataflow_pattern_functor.h>

#include <string>
#include <unordered_map>
#include <utility>

Expand Down Expand Up @@ -87,10 +88,14 @@ Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr);
*
* \param pattern The pattern to match
* \param expr The expression to patition
* \param attrs A set of parameter names and values to apply to the partitioned function
* \param check A callback function for checking more complicated properties of the matched
* expressions, returns true if the match is accepted and false otherwise
*
* \return Return the paritioned Expr.
*/
Expr PartitionPattern(DFPattern pattern, Expr expr);
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
PackedFunc check);

} // namespace relay
} // namespace tvm
Expand Down
16 changes: 11 additions & 5 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def match(self, expr: Expr) -> bool:
"""
return match(self, expr)

def partition(self, expr: Expr, attrs=None) -> Expr:
def partition(self, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
"""
Parition the expression into functions defined by this pattern

Expand All @@ -119,13 +119,16 @@ def partition(self, expr: Expr, attrs=None) -> Expr:
The expression to match.
attrs : Optional[Dict[str, Object]]
A dictionary of Attribute name/values to add to the paritioned function
check : Function
A function to perform more complicated checks on the matched expression.
Returns true if partitioning should proceed, false otherwise.

Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs replaced by function calls to that subgraph
"""
return partition(self, expr, attrs)
return partition(self, expr, attrs, check)

def dominates(self, parent, path=None):
"""
Expand Down Expand Up @@ -561,7 +564,7 @@ def rewrite(callbacks, expr: Expr) -> Expr:

return ffi.rewrite(tmp, expr)

def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr:
def partition(pattern: DFPattern, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
"""
Parition the expression into a series of functions that match the pattern

Expand All @@ -571,12 +574,15 @@ def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr:
The pattern to match
expr : tvm.relay.Expr
The expression to split into functions
expr : Optional[Dict[str, Object]]
attrs : Optional[Dict[str, Object]]
A dict of attributes to apply to the partitioned function
check : Function
A function to perform more complicated checks on the matched expression.
Returns true if partitioning should proceed, false otherwise.

Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs replaced by function calls to that subgraph
"""
return ffi.partition(pattern, expr, attrs)
return ffi.partition(pattern, expr, attrs, check)
17 changes: 10 additions & 7 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,11 +693,12 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatt
class PatternPartitioner : protected MixedModeMutator {
public:
Expr Partition(const DFPattern& pattern, const Expr& pre,
const Map<std::string, ObjectRef>& attrs) {
const Map<std::string, ObjectRef>& attrs, PackedFunc check) {
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(pattern, pre);
gid_assignments_ = grouper.GetGIDAssignments();
attrs_ = attrs;
check_ = check;
return this->VisitExpr(pre);
}

Expand All @@ -718,7 +719,8 @@ class PatternPartitioner : protected MixedModeMutator {

Expr DispatchVisitExpr(const Expr& pre) override {
auto post = MixedModeMutator::DispatchVisitExpr(pre);
if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) {
if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node &&
static_cast<bool>(check_(pre))) {
post = RewritePartition(groups_[gid_assignments_[pre]]);
}
return post;
Expand All @@ -727,16 +729,17 @@ class PatternPartitioner : protected MixedModeMutator {
Map<std::string, ObjectRef> attrs_;
std::vector<PatternGrouper::Group> groups_;
std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
PackedFunc check_;
};

Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
return PatternPartitioner().Partition(pattern, expr, attrs);
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
PackedFunc check) {
return PatternPartitioner().Partition(pattern, expr, attrs, check);
}

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition")
.set_body_typed([](DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs) {
return PartitionPattern(pattern, expr, attrs);
});
.set_body_typed([](DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
PackedFunc check) { return PartitionPattern(pattern, expr, attrs, check); });

} // namespace relay
} // namespace tvm
65 changes: 60 additions & 5 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
from tvm.relay.testing import run_opt_pass
import numpy as np

# NB: 1 corresponds to the C++ enum that specicfies this
Expand Down Expand Up @@ -880,7 +881,7 @@ def nested_diamond(inp, weight):
def get_BN(x, var, mean, beta, gamma, eps = 1e-5):
return gamma * (x - mean)/relay.op.sqrt(var + relay.const(eps)) + beta

def test_parition_batchnorm():
def test_partition_batchnorm():
x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
Expand All @@ -900,7 +901,7 @@ def test_parition_batchnorm():
partitioned = BatchnormCallback().pattern.partition(BN)
assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta))

def test_parition_double_batchnorm():
def test_partition_double_batchnorm():
x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
Expand All @@ -916,7 +917,7 @@ def test_parition_double_batchnorm():
betaf = relay.var('betaf')
gammaf = relay.var('gammaf')
f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_")
# The paritioner doesn't replace duplicates, so we use two copies of the function
# The partitioner doesn't replace duplicates, so we use two copies of the function
xf2 = relay.var('xf2')
varf2 = relay.var('varf2')
meanf2 = relay.var('meanf2')
Expand All @@ -928,6 +929,58 @@ def test_parition_double_batchnorm():
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)

def test_partition_check():
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
def check(pre):
return pre.args[0].attrs.data_layout == "NCHW"

x = relay.var('input')
w = relay.var('weight')
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)

xf = relay.var('input')
wf = relay.var('weight')
conv2df = relay.op.nn.conv2d(xf, wf)
reluf = relay.op.nn.relu(conv2df)
func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_")

reference = func(x, w)
partitioned = pattern.partition(relu, check=check)
assert tvm.ir.structural_equal(partitioned, reference)

conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC")
relu = relay.op.nn.relu(conv2d)
assert relu == pattern.partition(relu, check=check)

def test_partition_check_types():
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
def check(pre):
conv = pre.args[0]
return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1)

x = relay.var('input', shape=(1, 10, 10, 10))
w = relay.var('weight', shape=(10, 10, 3, 3))
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
relu = run_opt_pass(relu, relay.transform.InferType())

partitioned = pattern.partition(relu, check=check)
assert partitioned.op.attrs["PartitionedFromPattern"] == "nn.conv2d_nn.relu_"

conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC")
relu = relay.op.nn.relu(conv2d)
relu = run_opt_pass(relu, relay.transform.InferType())
assert relu == pattern.partition(relu, check=check)

x = relay.var('input', shape=(2, 10, 10, 10))
w = relay.var('weight', shape=(10, 10, 3, 3))
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
relu = run_opt_pass(relu, relay.transform.InferType())
assert relu == pattern.partition(relu, check=check)


if __name__ == "__main__":
test_match_op()
test_no_match_op()
Expand Down Expand Up @@ -957,6 +1010,8 @@ def test_parition_double_batchnorm():
test_algebraic_simplify()
test_partition_dominator()
test_quadruple_partition_dominator()
test_parition_batchnorm()
test_parition_double_batchnorm()
test_partition_batchnorm()
test_partition_double_batchnorm()
test_partition_check()
test_partition_check_types()