Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,8 @@ class PatternGrouper {
auto node = matcher_->expr_graph_.node_map_.at(kv.first);
for (auto* output : node->outputs_) {
// and the node is used by nodes outside of the group
if (memo.count(output->ref_) == 0) {
if (memo.count(output->ref_) == 0 &&
!matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
// Exit because nodes in this pattern's body are used outside the pattern
// fusing it would be invalid
return;
Expand Down
22 changes: 22 additions & 0 deletions src/relay/ir/indexed_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/relay/dataflow_pattern.h>

#include <memory>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <utility>
Expand Down Expand Up @@ -74,6 +75,27 @@ class IndexedGraph {
Node* dominator_parent_;
/*! \brief The nodes this node dominates */
std::vector<Node*> dominator_children_;

bool Dominates(const Node* other) {
std::stack<const Node*> stack;
std::unordered_set<const Node*> visited;
stack.push(this);
while (!stack.empty()) {
const Node* current = stack.top();
stack.pop();
for (auto node : current->dominator_children_) {
if (visited.count(node) == 0) {
if (other == node) {
return true;
} else {
stack.push(node);
}
visited.insert(node);
}
}
}
return false;
}
};
/*! \brief Construct the domination tree inside IndexedGraph */
void PostDom() {
Expand Down
71 changes: 71 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=unused-wildcard-import
import numpy as np
import pytest

import tvm
from tvm import relay
Expand Down Expand Up @@ -1470,6 +1471,76 @@ def test_partition_function():
assert tvm.ir.structural_equal(pattern.partition(expr), expr2)


def test_rewrite_function_with_fuzzy_body():
"""Allow Rewriting a function with a fuzzy body via dominator analysis"""
x = relay.var("x")
w = relay.var("w")
b = relay.var("b")

x1 = relay.var("x1")
w1 = relay.var("w1")

wc_x = wildcard()
wc_w = wildcard()
wc_b = wildcard()
wc_x1 = wildcard()
wc_w1 = wildcard()

func_pattern = FunctionPattern([wc_x1, wc_w1], wildcard())
pattern = func_pattern(wc_x, wc_w) + wc_b

func = relay.Function([x1, w1], relay.nn.conv2d(x1, w1))
expr = func(x, w) + b + b

class TestRewrite(DFPatternCallback):
def __init__(self):
super(TestRewrite, self).__init__()
self.pattern = pattern

def callback(self, pre, post, node_map):
return x + w

out = rewrite(TestRewrite(), expr)
assert tvm.ir.structural_equal(x + w, x + w)


@pytest.mark.skip(
"""TODO(mbrookhart): The current partitioner can't properly handle
the partitioned inputs on the fuzzy body"""
)
def test_partition_function_with_fuzzy_body():
"""
Allow Rewriting a function with a fuzzy body via dominator analysis
"""
x = relay.var("x")
w = relay.var("w")
b = relay.var("b")

x1 = relay.var("x1")
w1 = relay.var("w1")

wc_x = wildcard()
wc_w = wildcard()
wc_b = wildcard()
wc_x1 = wildcard()
wc_w1 = wildcard()

func_pattern = FunctionPattern([wc_x1, wc_w1], wildcard())
pattern = func_pattern(wc_x, wc_w) + wc_b

func = relay.Function([x1, w1], relay.nn.conv2d(x1, w1))
expr = func(x, w) + b + b

x2 = relay.var("x2")
w2 = relay.var("w2")
b2 = relay.var("b2")
func2 = relay.Function([x2, w2, b2], func(x2, w2) + b2).with_attr(
"PartitionedFromPattern", "FunctionCall_add_"
)
expr2 = func2(x, w, b) + b
assert tvm.ir.structural_equal(pattern.partition(expr), expr2)


def test_match_match():
add_pattern = is_op("add")(wildcard(), wildcard())

Expand Down