Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions include/tvm/relax/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,10 @@ class PatternContextNode : public Object {
kMustNot, /*!< All nodes except outputs only have internal depedencies in the matched graph. */
} allow_extern_use = kMay;
// src node -> <dst node, constraint type> constraints.
std::map<DFPattern, std::map<DFPattern, std::vector<PairCons>>> constraints;
// Dst nodes are kept in a vector to keep them ordered.
std::map<DFPattern, std::vector<std::pair<DFPattern, std::vector<PairCons>>>> constraints;
// Keep a separate vector of patterns to process constraints in a fixed order.
std::vector<DFPattern> src_ordered;

static constexpr const char* _type_key = "relax.dpl.PatternContext";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object);
Expand Down Expand Up @@ -224,9 +227,22 @@ class PatternContext : public ObjectRef {
* \param cons The constraint type. \sa PairCons
*/
void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) {
auto& vec = (*this)->constraints[producer][consumer];
ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend()) << "Constraint already exists";
vec.push_back(cons);
auto& pairs = (*this)->constraints[producer];
auto it = std::find_if(pairs.begin(), pairs.end(),
[consumer](auto p) { return p.first == consumer; });
if (it == pairs.end()) {
pairs.emplace_back(consumer, std::vector{cons});
} else {
auto& vec = it->second;
ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend())
<< "Constraint already exists";
vec.push_back(cons);
}

auto& patterns = (*this)->src_ordered;
if (std::find(patterns.begin(), patterns.end(), producer) == patterns.end()) {
patterns.push_back(producer);
}
}

/*! \brief Get the pass context object on the top of the stack */
Expand Down
104 changes: 54 additions & 50 deletions src/relax/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,15 +541,21 @@ struct RNode {
* \brief This method tries to match a real node and a pattern node along with its neighbors.
*/
static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m,
const std::map<const VarNode*, std::set<const VarNode*>>& def2use,
const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
if (nullptr != p->matched && p->matched == r->ptr) return true; // matched before.
if (p->matched != nullptr && p->matched == r->ptr) return true; // matched before.
if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return false;

std::stack<std::pair<PNode*, RNode*>> undo_stack{};

const auto commit = [&undo_stack](PNode* p, RNode* r) {
// match with each other.
// TODO(ganler, masahi): Why commit on the same p-r pair happens more than once?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the eventual results are the same either way. But yeah, doing a pre-check could be faster (it avoids the overhead of undo_stack.emplace(p, r)).

Copy link
Member Author

@masahi masahi Mar 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here my concern is not about performance. If the same pair is committed more than once, I think there is something odd about the matching algorithm.

Later I'll try to improve the implementation of the matching algorithm. In particular, I want to remove the try_match loop on parents:

for (auto& [pparent, constraints] : p->parents) {
bool any_cons_sat = false;
for (auto& rparent : r->parents) {
// skip if mismatch.
if (rparent->matched && rparent->matched != pparent->ptr) continue;
const auto& uses = def2use.at(rparent->ptr);
// check edge constraints.
bool cons_sat = true;
for (const auto& cons : constraints) {
if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) {
cons_sat = false;
break;
}
if (cons.index != -1) {
const auto& callees = use2def.at(r->ptr);
if (callees.size() <= static_cast<size_t>(cons.index) ||
callees[cons.index] != rparent->ptr) {
cons_sat = false;
break;
}
}
}
if (!cons_sat) continue;
any_cons_sat = true;
// try all parent R nodes that are not matched yet.
// as long as ppattern can match one node.
if (!pparent->matched && try_match(pparent, rparent, m, def2use, use2def)) {
This bidirectional matching has been a source of confusion for me (and debugging it is very hard). Since the constraint graph is assumed to be a DAG, I think we can start matching from the root nodes in topologically sorted order, and matching should be able to proceed purely in one direction.

Copy link
Contributor

@ganler ganler Mar 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Because the constraint or relation between nodes may or may not be one-directional, I initially made the matching bidirectional, so that the pattern can be matched as long as any node of the matched subgraph is given as the start hint.

For example, for A->B pattern, you can start matching from either A or B (forward or backward).

Specifying a fixed matching order definitely makes the code logic and debugging easier, but I am afraid it also reduces flexibility in some ways. I am not sure whether that flexibility is worth keeping, or whether we can have both.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have an example in test_dataflow_pattern.py that requires such flexibility? Similarly to how I don't understand the need for start_hint, I don't understand why we might want to start matching from B in A -> B.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also going to look at how similar projects like MLIR, OpenVINO, etc. implement general graph pattern matching.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I need some time to illustrate this, but my flight is taking off now. I will get back to you, probably in about 8 hours. 🥲

Copy link
Member Author

@masahi masahi Mar 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok :) I'll share my thoughts in the meantime.

Looking at an example in https://github.com/apache/tvm/blob/unity/tests/python/relax/test_dataflow_pattern.py#L540-L545, I cannot imagine how one would use start_hint in practice. Here, you wrote the input mod by hand, so you know that dfb.bindings[0].var is associated with the "left" branch of the CBRx2. But in general we don't have such information, especially in e2e scenarios for real-world models.

Even if there were a good use case for it, I claim that it doesn't justify making the API and the implementation more complicated (two additional params + the try_match loop on parents). If one has such advanced knowledge of the model structure and variables, they may as well have a different way to match & extract such a subgraph.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for bringing it up. When designing the API, it was my mistake that I did not know enough use cases, and I made it "more capable" at the cost of simplicity. Masa, you have more professional experience with how patterns look in practice, so I strongly agree that we should cut the implementation complexity that exists only for such long-tail use cases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, to explain why I was using the CONSTANT logic_op VAR code style: it is just a practice to avoid accidentally writing VAR = CONSTANT (assignment instead of comparison), which causes silent logic errors. But yeah, it is nice to normalize the code style.

if (p->ptr == r->matched) {
ICHECK_EQ(p->matched, r->ptr);
return;
}
ICHECK(r->matched == nullptr);
p->matched = r->ptr;
r->matched = p->ptr;
undo_stack.emplace(p, r);
Expand All @@ -568,31 +574,26 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m,
commit(p, r);

// match parent patterns.
for (auto& pparent_pairs : p->parents) {
PNode* pparent = pparent_pairs.first;
const std::vector<PairCons>& constraints = pparent_pairs.second;

for (auto& [pparent, constraints] : p->parents) {
bool any_cons_sat = false;
for (auto& rparent : r->parents) {
// skip if mismatch.
if (rparent->matched && rparent->matched != pparent->ptr) continue;

const auto& uses = def2use.at(rparent->ptr);
// skip if `rparent` is not used by `r`.
if (uses.cend() == uses.find(r->ptr)) continue;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check has been removed since the condition seems to be always false.


// check edge constraints.
bool cons_sat = true;
for (const auto& cons : constraints) {
if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) {
if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) {
cons_sat = false;
break;
}

if (-1 != cons.index) {
if (cons.index != -1) {
const auto& callees = use2def.at(r->ptr);
if (static_cast<size_t>(cons.index) >= callees.size() ||
rparent->ptr != callees[cons.index]) {
if (callees.size() <= static_cast<size_t>(cons.index) ||
callees[cons.index] != rparent->ptr) {
cons_sat = false;
break;
}
Expand All @@ -612,27 +613,24 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m,
}

// forward matching;
for (auto& pchild_pairs : p->children) {
PNode* pchild = pchild_pairs.first;
const std::vector<PairCons>& constraints = pchild_pairs.second;
for (auto& [pchild, constraints] : p->children) {
bool any_cons_sat = false;
for (auto& rchild : r->children) {
if (rchild->matched && rchild->matched != pchild->ptr) continue;

const auto& uses = def2use.at(r->ptr);
if (uses.cend() == uses.find(rchild->ptr)) continue;

// check edge constraints.
bool all_cons_pass = true;
for (const auto& cons : constraints) {
if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) {
if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) {
all_cons_pass = false;
break;
}

if (-1 != cons.index) {
if (cons.index != -1) {
const auto& callees = use2def.at(rchild->ptr);
if (static_cast<size_t>(cons.index) >= callees.size() || r->ptr != callees[cons.index]) {
if (callees.size() <= static_cast<size_t>(cons.index) || callees[cons.index] != r->ptr) {
all_cons_pass = false;
break;
}
Expand All @@ -648,13 +646,13 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m,
}
if (!pchild->matched || !any_cons_sat) return quit();
}

return true;
}

class MatcherUseDefAnalysis : public relax::ExprVisitor {
public:
std::map<const VarNode*, std::set<const VarNode*>> def2use;
std::vector<const VarNode*> vars;
std::map<const VarNode*, std::vector<const VarNode*>> def2use;
// caller -> callee table.
std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;

Expand All @@ -671,7 +669,15 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor {
void VisitExpr_(const VarNode* op) override {
if (nullptr == cur_user_) return;

def2use[op].insert(cur_user_);
auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* var) {
if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
vec.push_back(var);
}
};

check_and_push(def2use[op], cur_user_);
check_and_push(vars, op);

caller2callees[cur_user_].push_back(op);
}

Expand All @@ -682,6 +688,10 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor {

Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb,
Optional<Var> start_hint, bool must_include_hint) {
if (ctx->src_ordered.size() == 0) {
return {};
}

Map<DFPattern, Var> ret;
// TODO(@ganler): Handle non-may external use.
ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet.";
Expand All @@ -691,7 +701,6 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock& d
const auto var2val = AnalyzeVar2Value(dfb);
DFPatternMatcher matcher(var2val);

// std::map<const VarNode*, std::set<const VarNode*>>
MatcherUseDefAnalysis ud_analysis;
ud_analysis.VisitBindingBlock_(dfb.get());
const auto& def2use = ud_analysis.def2use;
Expand All @@ -701,9 +710,8 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock& d
std::unordered_map<const VarNode*, RNode> var2node;
var2node.reserve(dfb->bindings.size());

for (const auto& du : def2use) {
const VarNode* cur_var = du.first;
const std::set<const VarNode*>& uses = du.second;
for (const VarNode* cur_var : ud_analysis.vars) {
const auto& uses = def2use.at(cur_var);
RNode& cur_node = var2node[cur_var];
cur_node.ptr = cur_var;
for (const VarNode* use : uses) {
Expand All @@ -717,44 +725,40 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock& d
std::unordered_map<const DFPatternNode*, PNode> pattern2node;
pattern2node.reserve(ctx->constraints.size());

for (const auto& def2use_pattern : ctx->constraints) {
const DFPatternNode* def_pattern = def2use_pattern.first.get();
const std::map<DFPattern, std::vector<PairCons>>& uses = def2use_pattern.second;
PNode& def_node = pattern2node[def_pattern];
def_node.ptr = def_pattern;
for (const auto& [def_pattern, uses] : ctx->constraints) {
PNode& def_node = pattern2node[def_pattern.get()];
def_node.ptr = def_pattern.get();
def_node.children.reserve(uses.size());
for (const auto& use : uses) {
const auto& cons = use.second;
const DFPatternNode* use_pattern = use.first.get();
PNode& use_node = pattern2node[use_pattern];
use_node.ptr = use_pattern;
for (const auto& [use_pattern, cons] : uses) {
PNode& use_node = pattern2node[use_pattern.get()];
use_node.ptr = use_pattern.get();
use_node.parents.emplace_back(&def_node, std::ref(cons));
def_node.children.emplace_back(&use_node, std::ref(cons));
}
}

if (start_hint.defined()) {
Var v = start_hint.value();
auto rnode_ptr = var2node.find(v.get());
for (auto& ppair : pattern2node) {
if (try_match(&ppair.second, &rnode_ptr->second, &matcher, def2use, caller2callees)) {
for (auto ppair : pattern2node)
ret.Set(GetRef<DFPattern>(ppair.first), GetRef<Var>(ppair.second.matched));
if (start_hint) {
auto rnode_ptr = var2node.at(start_hint.value().get());
for (auto& p_node : pattern2node) {
if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use, caller2callees)) {
for (const auto& [df_pattern, pattern_node] : pattern2node)
ret.Set(GetRef<DFPattern>(df_pattern), GetRef<Var>(pattern_node.matched));
return ret;
}
}

if (must_include_hint) return ret;
}

PNode* pnode_start = &pattern2node.begin()->second;
PNode& pnode_start = pattern2node[ctx->src_ordered[0].get()];

if (!pnode_start->matched) {
for (auto& rpair : var2node) {
if (start_hint.defined() && start_hint.value().get() == rpair.first) continue;
if (try_match(pnode_start, &rpair.second, &matcher, def2use, caller2callees)) {
for (auto ppair : pattern2node)
ret.Set(GetRef<DFPattern>(ppair.first), GetRef<Var>(ppair.second.matched));
if (!pnode_start.matched) {
for (const auto& var : ud_analysis.vars) {
if (start_hint.defined() && start_hint.value().get() == var) continue;
RNode& r_node = var2node[var];
if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) {
for (const auto& [df_pattern, pattern_node] : pattern2node)
ret.Set(GetRef<DFPattern>(df_pattern), GetRef<Var>(pattern_node.matched));

return ret;
}
Expand Down
45 changes: 45 additions & 0 deletions tests/python/relax/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,5 +1006,50 @@ def rewriter(_, matchings):
tvm.ir.assert_structural_equal(rewritten, expected)


def test_attention_qkv():
@tvm.script.ir_module
class QKV_proj:
@R.function
def main(
x: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
) -> R.Tensor:
with R.dataflow():
lv0 = R.matmul(x, w0)
lv1 = R.matmul(x, w1)
lv2 = R.matmul(x, w2)
out = (lv0, lv1, lv2)
R.output(out)
return out

with PatternContext() as ctx:
inp_pat = wildcard()
Q_weight_pat = wildcard()
K_weight_pat = wildcard()
V_weight_pat = wildcard()

matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)

# TODO(masahi): Automate addition of used_by constraints during is_op
Copy link
Member Author

@masahi masahi Mar 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a follow-up PR to address this, which removes all the used_by calls below. @ganler

inp_pat.used_by(matmul1, 0)
inp_pat.used_by(matmul2, 0)
inp_pat.used_by(matmul3, 0)

Q_weight_pat.only_used_by(matmul1, 1)
K_weight_pat.only_used_by(matmul2, 1)
V_weight_pat.only_used_by(matmul3, 1)

dfb = QKV_proj["main"].body.blocks[0]
out = ctx.match_dfb(dfb)

assert out[Q_weight_pat].name_hint == "w0"
assert out[K_weight_pat].name_hint == "w1"
assert out[V_weight_pat].name_hint == "w2"


if __name__ == "__main__":
tvm.testing.main()