src/relax/transform/fuse_ops.cc: 29 additions & 2 deletions
@@ -1073,7 +1073,11 @@ class PatternBasedPartitioner : ExprVisitor {
     current_block_use_def_ = {};
   }
 
-  void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make<Group>(); }
+  void VisitVarDef(const Var& var) final {
+    Group* g = arena_->make<Group>();
+    group_map_[var.get()] = g;
+    vars_in_group_[g].push_back(var);
+  }
 
   void VisitBinding_(const VarBindingNode* binding) final {
     bindings_.Set(binding->var, binding->value);
@@ -1097,7 +1101,13 @@ class PatternBasedPartitioner : ExprVisitor {
         auto g = GetGroup(match);
         if (g && g->FindRoot()->num_nodes > 1) {
           // This expression has already been matched to a previous pattern.
-          return;
+          // If the prior matched subgraph is subsumed by the new matched one,
+          // we can safely merge them, eventually obtaining a maximal matched subgraph.
+          // Otherwise, merging them would result in an incorrect subgraph,
+          // so we keep the prior subgraph and discard the current one by returning directly.
+          auto vars_in_prior_matched_graph = vars_in_group_[g];
+          if (!GraphSubsumedInMatchedValues(vars_in_prior_matched_graph, matches_opt.value()))
+            return;
         }
       }
     }
@@ -1145,6 +1155,7 @@ class PatternBasedPartitioner : ExprVisitor {
     if (group_map_[e.get()] != to) {
       --group_map_[e.get()]->num_nodes;
       group_map_[e.get()]->parent = to;
+      vars_in_group_[to].push_back(e);
       ++to->num_nodes;
     }
   }
@@ -1181,6 +1192,21 @@ class PatternBasedPartitioner : ExprVisitor {
                               current_block_use_def_, value_to_bound_var_);
   }
 
+  // Check whether a previously matched subgraph is subsumed by the current match result.
+  bool GraphSubsumedInMatchedValues(const Array<Expr>& vars_in_graph,
+                                    const Map<DFPattern, Expr>& matched_result) {
+    std::set<Expr> matched_vars;
+    for (const auto& [pat, match] : matched_result) {
+      if (pat->IsInstance<CallPatternNode>() || pat->IsInstance<TupleGetItemPatternNode>())
+        matched_vars.insert(value_to_bound_var_[match]);
+    }
+
+    for (const auto& var : vars_in_graph) {
+      if (matched_vars.find(var) == matched_vars.end()) return false;
+    }
+    return true;
+  }
+
   String pat_name_;
   DFPattern pat_;
   Map<String, DFPattern> annotation_pat_;
@@ -1191,6 +1217,7 @@ class PatternBasedPartitioner : ExprVisitor {
   Map<Expr, Var> value_to_bound_var_;
   Map<Var, Array<Var>> current_block_use_def_;
   GroupMap group_map_;
+  std::map<Group*, Array<Expr>> vars_in_group_;
 };
 
 /*!
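For intuition, the new GraphSubsumedInMatchedValues check reduces to a set-inclusion test: the prior group may be merged only when every variable it already captured is bound again by the new, larger match. A minimal stand-alone sketch of that decision in plain Python (sets stand in for TVM's Array<Expr> and Map<DFPattern, Expr>; the name graph_subsumed is illustrative, not TVM API):

def graph_subsumed(prior_group_vars, new_match_vars):
    # Mirrors GraphSubsumedInMatchedValues: the prior group is subsumed
    # only if every one of its bound variables also appears among the
    # bound variables of the new match.
    return set(prior_group_vars) <= set(new_match_vars)

# matmul+add ({lv0, lv1}) is contained in matmul+add+clip ({lv0, lv1, lv2}),
# so the two groups can be merged into one maximal partition.
assert graph_subsumed({"lv0", "lv1"}, {"lv0", "lv1", "lv2"})

# A later match that misses a variable of the prior group is discarded
# instead, mirroring the early return in the C++ code above.
assert not graph_subsumed({"lv0", "lv1"}, {"lv1", "lv2"})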
tests/python/relax/test_transform_fuse_ops_by_pattern.py: 26 additions & 0 deletions

@@ -1217,5 +1217,31 @@ def inner_func(
     tvm.ir.assert_structural_equal(Expected, After)
 
+
+def test_match_maximal_subgraph():
+    @R.function
+    def func(
+        x: R.Tensor((32, 8), dtype="int32"),
+        y: R.Tensor((8, 8), dtype="int32"),
+        bias: R.Tensor((8,), dtype="int32"),
+    ) -> R.Tensor((32, 8), dtype="int32"):
+        R.func_attr({"global_symbol": "main"})
+        with R.dataflow():
+            lv0 = R.matmul(x, y, out_dtype="int32")
+            lv1 = R.add(lv0, bias)
+            lv2 = R.clip(lv1, -128, 127)
+            R.output(lv2)
+        return lv2
+
+    mod = tvm.IRModule({"main": func})
+
+    matmul = is_op("relax.matmul")(wildcard(), wildcard())
+    matmul_add = is_op("relax.add")(matmul, wildcard())
+    pattern = matmul_add | is_op("relax.clip")(matmul_add, wildcard(), wildcard())
+
+    partitioned = relax.transform.FuseOpsByPattern([("orclip", pattern)])(mod)
+    func_names = [name.name_hint for (name, _) in partitioned.functions.items()]
+    assert "fused_relax_matmul_relax_add_relax_clip" in func_names
+
 
 if __name__ == "__main__":
     pytest.main([__file__])
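The new test exercises exactly the path fixed above: the alternative pattern's shorter arm (matmul followed by add) matches first at lv1, creating a group with more than one node; when the longer arm then matches matmul, add, and clip at lv2, the prior {lv0, lv1} group is subsumed by the new match, so the groups are merged and the partitioned module contains the maximal fused function fused_relax_matmul_relax_add_relax_clip instead of stopping at matmul plus add.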