diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 04c07c439cac..e89c5e44454f 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1073,7 +1073,11 @@ class PatternBasedPartitioner : ExprVisitor {
     current_block_use_def_ = {};
   }
 
-  void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make<Group>(); }
+  void VisitVarDef(const Var& var) final {
+    Group* g = arena_->make<Group>();
+    group_map_[var.get()] = g;
+    vars_in_group_[g].push_back(var);
+  }
 
   void VisitBinding_(const VarBindingNode* binding) final {
     bindings_.Set(binding->var, binding->value);
@@ -1097,7 +1101,13 @@
         auto g = GetGroup(match);
         if (g && g->FindRoot()->num_nodes > 1) {
           // This expression has already been matched to a previous pattern.
-          return;
+          // If the prior matched subgraph is subsumed by the new matched one,
+          // we can safely merge them, eventually obtaining a maximal matched subgraph.
+          // Otherwise, merging them would result in an incorrect subgraph,
+          // so we keep the prior subgraph and discard the current one by returning directly.
+          auto vars_in_prior_matched_graph = vars_in_group_[g];
+          if (!GraphSubsumedInMatchedValues(vars_in_prior_matched_graph, matches_opt.value()))
+            return;
         }
       }
     }
@@ -1145,6 +1155,7 @@
     if (group_map_[e.get()] != to) {
       --group_map_[e.get()]->num_nodes;
       group_map_[e.get()]->parent = to;
+      vars_in_group_[to].push_back(e);
       ++to->num_nodes;
     }
   }
@@ -1181,6 +1192,21 @@
                                current_block_use_def_, value_to_bound_var_);
   }
 
+  // Check whether a previously matched subgraph is subsumed by the current matched result.
+  bool GraphSubsumedInMatchedValues(const Array<Expr>& vars_in_graph,
+                                    const Map<DFPattern, Expr>& matched_result) {
+    std::set<Expr> matched_vars;
+    for (const auto& [pat, match] : matched_result) {
+      if (pat->IsInstance<CallPatternNode>() || pat->IsInstance<TupleGetItemPatternNode>())
+        matched_vars.insert(value_to_bound_var_[match]);
+    }
+
+    for (const auto var : vars_in_graph) {
+      if (matched_vars.find(var) == matched_vars.end()) return false;
+    }
+    return true;
+  }
+
   String pat_name_;
   DFPattern pat_;
   Map<String, DFPattern> annotation_pat_;
@@ -1191,6 +1217,7 @@
   Map<Expr, Var> value_to_bound_var_;
   Map<Var, Array<Var>> current_block_use_def_;
   GroupMap group_map_;
+  std::map<Group*, Array<Expr>> vars_in_group_;
 };
 
 /*!
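The new GraphSubsumedInMatchedValues helper boils down to set containment: a previously
formed group may only be absorbed into the new match when every binding it covers is also
covered by the new, larger match; otherwise the prior group is kept and the new match is
dropped. A minimal Python sketch of that check, for illustration only (the names
graph_subsumed_in_matched_values, prior_group_vars, and new_match_vars are placeholders,
not TVM APIs):

    def graph_subsumed_in_matched_values(prior_group_vars, new_match_vars):
        # True when every var bound in the previously matched subgraph also
        # appears among the vars covered by the new match.
        new_vars = set(new_match_vars)
        return all(v in new_vars for v in prior_group_vars)

    # A prior match over {matmul, add} is subsumed by a new match over
    # {matmul, add, clip}, so the two groups can be merged.
    assert graph_subsumed_in_matched_values(["matmul", "add"], ["matmul", "add", "clip"])
    # A prior match covering an unrelated binding is not subsumed, so it is kept.
    assert not graph_subsumed_in_matched_values(["matmul", "other"], ["matmul", "add", "clip"])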
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 5e700b277f32..f5905f764351 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -1217,5 +1217,31 @@ def inner_func(
     tvm.ir.assert_structural_equal(Expected, After)
 
 
+def test_match_maximal_subgraph():
+    @R.function
+    def func(
+        x: R.Tensor((32, 8), dtype="int32"),
+        y: R.Tensor((8, 8), dtype="int32"),
+        bias: R.Tensor((8,), dtype="int32"),
+    ) -> R.Tensor((32, 8), dtype="int32"):
+        R.func_attr({"global_symbol": "main"})
+        with R.dataflow():
+            lv0 = R.matmul(x, y, out_dtype="int32")
+            lv1 = R.add(lv0, bias)
+            lv2 = R.clip(lv1, -128, 127)
+            R.output(lv2)
+        return lv2
+
+    mod = tvm.IRModule({"main": func})
+
+    matmul = is_op("relax.matmul")(wildcard(), wildcard())
+    matmul_add = is_op("relax.add")(matmul, wildcard())
+    pattern = matmul_add | is_op("relax.clip")(matmul_add, wildcard(), wildcard())
+
+    partitioned = relax.transform.FuseOpsByPattern([("orclip", pattern)])(mod)
+    func_names = [name.name_hint for (name, _) in partitioned.functions.items()]
+    assert "fused_relax_matmul_relax_add_relax_clip" in func_names
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
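In the new test, the OrPattern gives the partitioner two ways to match: matmul followed by add alone, or matmul, add, and clip together. The shorter branch matches first at the add binding and forms a small group; when the clip binding is visited, the longer branch matches a superset of those same bindings. With the subsumption check in place, the prior group is expected to be merged into the larger match, so the partitioned module should contain a single composite function named fused_relax_matmul_relax_add_relax_clip rather than stopping at the matmul and add pair.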