diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 0d9481312137..cfacd41487c8 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -730,7 +730,8 @@ class PatternGrouper {
         auto node = matcher_->expr_graph_.node_map_.at(kv.first);
         for (auto* output : node->outputs_) {
           // and the node is used by nodes outside of the group
-          if (memo.count(output->ref_) == 0) {
+          if (memo.count(output->ref_) == 0 &&
+              !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
             // Exit because nodes in this pattern's body are used outside the pattern
             // fusing it would be invalid
             return;
diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h
index 4bbb741b760d..d073bcaeea5c 100644
--- a/src/relay/ir/indexed_graph.h
+++ b/src/relay/ir/indexed_graph.h
@@ -27,6 +27,7 @@
 #include <tvm/relay/dataflow_pattern.h>
 
 #include <memory>
+#include <stack>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
@@ -74,6 +75,33 @@ class IndexedGraph {
     Node* dominator_parent_;
     /*! \brief The nodes this node dominates */
     std::vector<Node*> dominator_children_;
+
+    /*!
+     * \brief Check whether this node strictly dominates another node.
+     *
+     * Searches the dominator tree below this node by following
+     * dominator_children_ links. The node itself is never compared against
+     * \p other, so a node does not dominate itself.
+     *
+     * \param other The node to look for among this node's dominator descendants.
+     * \return true if \p other is reachable through dominator_children_, false otherwise.
+     */
+    bool Dominates(const Node* other) const {
+      std::stack<const Node*> stack;
+      std::unordered_set<const Node*> visited;
+      stack.push(this);
+      while (!stack.empty()) {
+        const Node* current = stack.top();
+        stack.pop();
+        for (const Node* child : current->dominator_children_) {
+          // insert() reports whether the child was new, so each node is expanded once.
+          if (!visited.insert(child).second) continue;
+          if (child == other) return true;
+          stack.push(child);
+        }
+      }
+      return false;
+    }
   };
   /*! \brief Construct the domination tree inside IndexedGraph */
   void PostDom() {
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index e7b367b8f631..15d3ee035450 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=unused-wildcard-import
 import numpy as np
+import pytest
 
 import tvm
 from tvm import relay
@@ -1470,6 +1471,77 @@ def test_partition_function():
     assert tvm.ir.structural_equal(pattern.partition(expr), expr2)
 
 
+def test_rewrite_function_with_fuzzy_body():
+    """Allow Rewriting a function with a fuzzy body via dominator analysis"""
+    x = relay.var("x")
+    w = relay.var("w")
+    b = relay.var("b")
+
+    x1 = relay.var("x1")
+    w1 = relay.var("w1")
+
+    wc_x = wildcard()
+    wc_w = wildcard()
+    wc_b = wildcard()
+    wc_x1 = wildcard()
+    wc_w1 = wildcard()
+
+    func_pattern = FunctionPattern([wc_x1, wc_w1], wildcard())
+    pattern = func_pattern(wc_x, wc_w) + wc_b
+
+    func = relay.Function([x1, w1], relay.nn.conv2d(x1, w1))
+    expr = func(x, w) + b + b
+
+    class TestRewrite(DFPatternCallback):
+        def __init__(self):
+            super(TestRewrite, self).__init__()
+            self.pattern = pattern
+
+        def callback(self, pre, post, node_map):
+            return x + w
+
+    out = rewrite(TestRewrite(), expr)
+    # Only the inner `func(x, w) + b` matches the pattern, so the outer `+ b` survives.
+    assert tvm.ir.structural_equal(out, x + w + b)
+
+
+@pytest.mark.skip(
+    """TODO(mbrookhart): The current partitioner can't properly handle
+    the partitioned inputs on the fuzzy body"""
+)
+def test_partition_function_with_fuzzy_body():
+    """
+    Allow Rewriting a function with a fuzzy body via dominator analysis
+    """
+    x = relay.var("x")
+    w = relay.var("w")
+    b = relay.var("b")
+
+    x1 = relay.var("x1")
+    w1 = relay.var("w1")
+
+    wc_x = wildcard()
+    wc_w = wildcard()
+    wc_b = wildcard()
+    wc_x1 = wildcard()
+    wc_w1 = wildcard()
+
+    func_pattern = FunctionPattern([wc_x1, wc_w1], wildcard())
+    pattern = func_pattern(wc_x, wc_w) + wc_b
+
+    func = relay.Function([x1, w1], relay.nn.conv2d(x1, w1))
+    expr = func(x, w) + b + b
+
+    x2 = relay.var("x2")
+    w2 = relay.var("w2")
+    b2 = relay.var("b2")
+    func2 = relay.Function([x2, w2, b2], func(x2, w2) + b2).with_attr(
+        "PartitionedFromPattern", "FunctionCall_add_"
+    )
+    expr2 = func2(x, w, b) + b
+    assert tvm.ir.structural_equal(pattern.partition(expr), expr2)
+
+
 def test_match_match():
     add_pattern = is_op("add")(wildcard(), wildcard())