diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 290ee42effe3..1a1cfa9d23e3 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -831,10 +831,19 @@ class PatternRewriter : ExprMutator { } BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final { - if (!ctx_) { - return ExprMutator::VisitBindingBlock_(block_node); + if (ctx_) { + return RewriteDataflowBlockFixedPoint(GetRef(block_node)); + } + + DataflowBlock prev = GetRef(block_node); + while (true) { + DataflowBlock next = Downcast(ExprMutator::VisitBindingBlock_(prev.get())); + if (StructuralEqual()(prev, next)) { + return std::move(next); + } else { + prev = next; + } } - return RewriteDataflowBlockFixedPoint(GetRef(block_node)); } private: diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 202db9b5b3d1..a804ea46c6d6 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1318,5 +1318,53 @@ def rewriter(_expr, matches): tvm.ir.assert_structural_equal(after, expected) +def test_repeated_pattern_match(): + """rewrite_call should iterate until convergence""" + + @R.function(private=True) + def before( + x: R.Tensor((1024,)), + y: R.Tensor((1024,)), + z: R.Tensor((1024,)), + ): + with R.dataflow(): + a = R.add(x, y) + b = R.add(a, z) + out = R.multiply(b, R.const(5.0)) + R.output(out) + return out + + @R.function(private=True) + def expected( + x: R.Tensor((1024,)), + y: R.Tensor((1024,)), + z: R.Tensor((1024,)), + ): + with R.dataflow(): + x = R.multiply(x, R.const(5.0)) + y = R.multiply(y, R.const(5.0)) + a = R.add(x, y) + z = R.multiply(z, R.const(5.0)) + b = R.add(a, z) + R.output(b) + return b + + pattern_add_lhs = wildcard() + pattern_add_rhs = wildcard() + pattern_add = is_op("relax.add")(pattern_add_lhs, pattern_add_rhs) + + mul_const = is_const() + pattern_mul = is_op("relax.multiply")(pattern_add, mul_const) + + pattern = pattern_mul + + def rewriter(_expr, matches): + const = matches[mul_const] + return (matches[pattern_add_lhs] * const) + (matches[pattern_add_rhs] * const) + + after = rewrite_call(pattern, rewriter, before) + tvm.ir.assert_structural_equal(after, expected) + + if __name__ == "__main__": tvm.testing.main()