Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/relax/dpl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@

from .pattern import *
from .context import *
from .rewrite import rewrite_call, rewrite_bindings
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing rewrite function is renamed to rewrite_call to make it clear that it is CallNode rewriting. And together with the new dataflow block rewriting function, it is put under the new file.

40 changes: 2 additions & 38 deletions python/tvm/relax/dpl/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# pylint: disable=pointless-statement

import typing
from typing import Dict, List, Optional, Tuple, Union, Callable
from typing import Dict, List, Optional, Tuple, Union

import tvm
import tvm._ffi as tvm_ffi
Expand All @@ -31,7 +31,7 @@
from ...ir import make_node
from ...ir.base import Node
from ...runtime import Object
from ..expr import Expr, Var, Function
from ..expr import Expr, Var
from . import _ffi as ffi


Expand Down Expand Up @@ -1123,39 +1123,3 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
return is_op(activation)(out)

return out


def rewrite(
pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
) -> Function:
"""
Rewrite a function with the given pattern and the rewriter function.

Parameters
----------
pattern: DFPattern
The pattern to match.

rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr]
The function to be called on a successful matching for rewriting. Given the matched
call node and the map of patterns and matched expressions, it should return a new call node
to replace the original one or the original matched call node as is.

For example, to replace x + x with 2 * x, we can write the rewriter as follows:
```
x = wildcard()
pattern = is_op("relax.add")(x, x)

def rewriter(orig, matchings):
return R.multiply(matchings[x], R.const(2, "float32"))
```

func: Function
The function to rewrite.

Returns
-------
rewritten_func: Function
The rewritten or the input function, depending on the pattern matching result.
"""
return ffi.rewrite(pattern, rewriter, func)
115 changes: 115 additions & 0 deletions python/tvm/relax/dpl/rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""APIs for pattern-based rewriting."""
from typing import Dict, Callable
from .pattern import DFPattern
from .context import PatternContext

from ..expr import Expr, Function, Var
from . import _ffi as ffi


def rewrite_call(
pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
) -> Function:
"""
Rewrite a function with the given pattern and the rewriter function.

Parameters
----------
pattern: DFPattern
The pattern to match.

rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr]
The function to be called on a successful matching for rewriting. Given the matched
call node and the map of patterns and matched expressions, it should return a new call node
to replace the original one or the original matched call node as is.

For example, to replace x + x with 2 * x, we can write the rewriter as follows:
```
x = wildcard()
pattern = is_op("relax.add")(x, x)

def rewriter(orig, matchings):
return R.multiply(matchings[x], R.const(2, "float32"))
```

func: Function
The function to rewrite.

Returns
-------
rewritten_func: Function
The rewritten or the input function, depending on the pattern matching result.
"""
return ffi.rewrite_call(pattern, rewriter, func)


def rewrite_bindings(
ctx: PatternContext, rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]], func: Function
) -> Function:
"""
Rewrite a function with the given pattern and the rewriter function.

Parameters
----------
ctx: PatternContext
The pattern constraint context under which rewriting takes place.

rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]]
The function to be called on a successful matching for rewriting. Given the map of patterns
and corresponding variables (bound variables or parameters), it should return a map that
specifies new values for matched bound variables.

For example, to rewrite three matmuls for QKV projection in transformer models into one
matmul followed by slicing, one can use the follwoing rewriter:
```
inp_pat = wildcard()
Q_weight_pat, K_weight_pat, V_weight_pat = wildcard(), wildcard(), wildcard()

matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)

def rewriter(matchings):
inp = matchings[inp_pat]
Q_weight = matchings[Q_weight_pat]
K_weight = matchings[K_weight_pat]
V_weight = matchings[V_weight_pat]
width = Q_weight.struct_info.shape[1]

concat = R.concat([Q_weight, K_weight, V_weight], axis=1)
matmul = R.matmul(inp, concat)
Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width])
K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2])
V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3])

# matchings[matmul1] gives the bound variable in the binding whose RHS matches with
# the matmul1 pattern. For example, lv0 in lv0 = R.matmul(x1, w0).
# We want to replace the RHS of this binding with Q.
return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V}
```
Copy link
Member Author

@masahi masahi Mar 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope the API for the rewriter makes sense and the usage is intuitive. It took me a while to workout this interface together with how the rewriting mutator should be implemented in dataflow_matcher.cc.

cc @ganler


func: Function
The function to rewrite.

Returns
-------
rewritten_func: Function
The rewritten or the input function, depending on the pattern matching result.
"""
return ffi.rewrite_bindings(ctx, rewriter, func)
116 changes: 107 additions & 9 deletions src/relax/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -780,18 +780,29 @@ Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const Datafl
TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);

/*!
* \brief Apply pattern matching to each call node and replace matching ones with the output of
* a user-provided rewriter function.
* \brief Apply pattern matching to each call node and dataflow block, and replace matching ones
* with the output of a user-provided rewriter function.
*/
class PatternRewriter : ExprMutator {
public:
using ExprMutator::VisitBindingBlock_;
using ExprMutator::VisitExpr_;

PatternRewriter(DFPattern pat, PackedFunc rewriter_func)
: pattern_(pat), rewriter_func_(rewriter_func) {}
PatternRewriter(DFPattern pat, PackedFunc rewriter_func,
const std::unordered_set<const VarNode*>& params)
: pattern_(pat), rewriter_func_(rewriter_func), params_(params) {}

static Expr Run(DFPattern pat, PackedFunc rewriter_func, Function f) {
PatternRewriter rewriter(pat, rewriter_func);
PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func,
const std::unordered_set<const VarNode*>& params)
: ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}

template <typename PatternType>
static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depending on the passed pattern type (DFPattern or PatternContext), it does either call node rewriting or dataflow block rewriting. It never does both in a single pass (obvious from the constructors).

std::unordered_set<const VarNode*> params;
for (const auto& p : f->params) {
params.insert(p.get());
}
PatternRewriter rewriter(pat, rewriter_func, params);
return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f)));
}

Expand All @@ -807,25 +818,112 @@ class PatternRewriter : ExprMutator {

Expr VisitExpr_(const CallNode* call_node) final {
auto call = ExprMutator::VisitExpr_(call_node);
if (auto matches_opt = ExtractMatchedExpr(pattern_, call, bindings_)) {
if (!pattern_) {
return call;
} else if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), call, bindings_)) {
auto rewriten_expr = rewriter_func_(call, matches_opt.value());
memo_[call_node] = rewriten_expr;
return rewriten_expr;
}
return call;
}

BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final {
if (!ctx_) {
return ExprMutator::VisitBindingBlock_(block_node);
}
return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
}

private:
DFPattern pattern_;
void EmitUsedVars(Expr val, const Array<Binding>& pending_bindings,
std::unordered_set<const VarNode*>* emitted_vars) {
std::unordered_set<const VarNode*> unemitted_vars;
PostOrderVisit(val, [=, &unemitted_vars](Expr e) {
if (auto v = e.as<VarNode>(); v && !emitted_vars->count(v)) {
unemitted_vars.insert(v);
}
});

if (unemitted_vars.empty()) {
return;
}

size_t num_unemitted = unemitted_vars.size();
for (size_t i = 0; i < pending_bindings.size(); ++i) {
const auto& binding = pending_bindings[i];
if (auto var_bind = binding.as<VarBindingNode>();
var_bind && unemitted_vars.count(var_bind->var.get())) {
// var_bind->value may also depend on other unemitted vars in this range
Array<Binding> prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i);
EmitUsedVars(var_bind->value, prev_bindings, emitted_vars);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to get rid of this recursive call and make sure we traverse pending_bindings only once. The issue is that PostOrderVisit does not look into subexpressions when it encounters the corresponding bound variable. For example, given the contrived input bindings below,

with R.dataflow():
    w0_t = R.permute_dims(w0, axes=None)
    lv0 = R.matmul(x1, w0_t)
    w1_t = R.permute_dims(w1, axes=None)
    w1_t_t = R.permute_dims(w1_t, axes=None)
    lv1 = R.matmul(x1, w1_t_t)
    w2_t = R.permute_dims(w2, axes=None)
    lv2 = R.matmul(x1, w2_t)

we need to emit all permute_dims binding before emitting concat and the combined matmul, since concat depends on all weights some of which are defined after the first matmul. When PostOrderVisit is applied on R.matmul(x1, w1_t_t), w1_t is not visited. So even though we need to emit w1_t before w1_t_t, w1_t is not put into the initial unemitted_vars set.

I think we can use AnalyzeVar2Value on the input function to get bindings, and recursively traverse the bound expression when we encounter a new unemitted var. But I find that a bit complicated for a simple job like this, so I'm looking for a simpler solution. For now I'm keeping this recursive solution that is not efficient but extremely simple.

this->VisitBinding(binding);
emitted_vars->insert(var_bind->var.get());
if (--num_unemitted == 0) {
return;
}
}
}
}

// Repeat until all matchable subsets of bindings are rewritten.
BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to apply rewriting repeatedly since, for example, the same QKV projection pattern appears a number of times in a single DFB.

if (auto matches = MatchGraph(ctx_.value(), Downcast<DataflowBlock>(block))) {
builder_->BeginDataflowBlock();
Map<Var, Expr> replacements = rewriter_func_(matches.value());

std::unordered_set<const VarNode*> emitted_vars;

for (size_t i = 0; i < block->bindings.size(); ++i) {
const auto& binding = block->bindings[i];
if (auto var_bind = binding.as<VarBindingNode>()) {
if (replacements.count(var_bind->var)) {
auto new_val = replacements[var_bind->var];
Array<Binding> pending_bindings(block->bindings.begin() + i + 1, block->bindings.end());
// Make sure there is no unbound variable used in the new value before it is emitted
EmitUsedVars(new_val, pending_bindings, &emitted_vars);
this->ReEmitBinding(var_bind, builder_->Normalize(new_val));
} else if (!emitted_vars.count(var_bind->var.get())) {
this->VisitBinding(binding);
emitted_vars.insert(var_bind->var.get());
}
} else {
this->VisitBinding(binding);
}
}
return RewriteDataflowBlockFixedPoint(builder_->EndBlock());
}
return block;
}

/*! \brief The pattern for rewriting call nodes */
Optional<DFPattern> pattern_;
/*! \brief The pattern constraint contexts for rewriting dataflow blocks */
Optional<PatternContext> ctx_;
/*!
* \brief The user-provided rewriter function. Its signature and semantics are:
* - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the matched
* call node and the map of patterns and matched expressions, it should return a new call node
* to replace the original one or the original matched call node as is.
* - Map<DFPattern, Var> -> Map<Var, Expr> for dataflow block rewriting. Given the map of patterns
* and corresponding variables (bound variables or parameters), it should return a map that
* specifies new values for matched bound variables.
*/
PackedFunc rewriter_func_;
std::unordered_set<const VarNode*> params_;
Map<Var, Expr> bindings_;
std::unordered_map<const Object*, Expr> memo_;
};

TVM_REGISTER_GLOBAL("relax.dpl.rewrite")
TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call")
.set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) {
return PatternRewriter::Run(pat, rewriter, f);
});

TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings")
.set_body_typed([](const PatternContext& ctx, PackedFunc rewriter, Function f) {
return PatternRewriter::Run(ctx, rewriter, f);
});

} // namespace relax
} // namespace tvm
Loading