10 changes: 9 additions & 1 deletion include/tvm/tir/transform.h
@@ -672,9 +672,17 @@ TVM_DLL Pass InjectPTXAsyncCopy();

/*!
* \brief Remove the weight layout rewrite block
* \param skip_ndarray_rewrite If true, the exact rewrite of the NDArray contents according to
* the given index map is skipped. Only the shape of the NDArray is transformed correctly; the
* contents of the destination array are left arbitrary rather than correctly relaid out.
*
* When this pass is called many times during MetaSchedule tuning, the raw data of the NDArray
* before and after the rewrite does not matter. Since rewriting an NDArray layout with
* IndexMap's MapNDArray is currently slow, skipping the exact rewrite is sometimes necessary.
*
* \return The pass.
*/
TVM_DLL Pass RemoveWeightLayoutRewriteBlock();
TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite = false);

/*!
* \brief Add the explicit local stage for the shared memory access on GPU.
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/builder/local_builder.py
@@ -257,7 +257,7 @@ def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, NDA
from tvm.tir.transform import RemoveWeightLayoutRewriteBlock

# pylint: enable=import-outside-toplevel
mod = RemoveWeightLayoutRewriteBlock()(mod)
mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod)
return tvm_build(mod, target=target)


16 changes: 14 additions & 2 deletions python/tvm/tir/transform/transform.py
@@ -964,14 +964,26 @@ def InjectPTXAsyncCopy():
return _ffi_api.InjectPTXAsyncCopy() # type: ignore


def RemoveWeightLayoutRewriteBlock():
def RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=False):
"""Remove weight layout rewrite block before benchmarking during tuning stage.

Parameters
----------
skip_ndarray_rewrite : bool
If True, the exact rewrite of the NDArray contents according to the given index map is
skipped. Only the shape of the NDArray is transformed correctly; the contents of the
destination array are left arbitrary rather than correctly relaid out.

When this pass is called many times during MetaSchedule tuning, the raw data of the NDArray
before and after the rewrite does not matter. Since rewriting an NDArray layout with
IndexMap's MapNDArray is currently slow, skipping the exact rewrite is sometimes necessary.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.RemoveWeightLayoutRewriteBlock() # type: ignore
return _ffi_api.RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite) # type: ignore


def ManifestSharedMemoryLocalStage():
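A minimal usage sketch of the new flag (illustrative only: `mod` stands for any IRModule whose PrimFunc still contains a MetaSchedule layout-rewrite block; everything except the pass name and the `skip_ndarray_rewrite` parameter is assumed for the example):

```python
from tvm.tir.transform import RemoveWeightLayoutRewriteBlock

# During tuning, the rewritten constant's contents do not matter, so the
# slow exact NDArray rewrite can be skipped; only its shape is corrected.
tuning_mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod)

# For final compilation the constant must actually be relaid out, so the
# exact rewrite is kept (this is also the default).
compile_mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=False)(mod)
```

This mirrors how the callers in this PR choose the flag: the MetaSchedule local builder, argument-info extraction, and feature extraction pass `True`, while the Relay `te_compiler_cache` path keeps the exact rewrite with `False`.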
3 changes: 2 additions & 1 deletion src/meta_schedule/arg_info.cc
@@ -103,7 +103,8 @@ Array<ArgInfo> ArgInfo::FromPrimFunc(const tir::PrimFunc& func) {

Array<ArgInfo> ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) {
if (remove_preproc) {
IRModule new_mod = tir::transform::RemoveWeightLayoutRewriteBlock()(mod);
IRModule new_mod =
tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ true)(mod);
return ArgInfo::FromPrimFunc(FindEntryFunc(new_mod));
}
return ArgInfo::FromPrimFunc(FindEntryFunc(mod));
2 changes: 1 addition & 1 deletion src/meta_schedule/feature_extractor/per_store_feature.cc
@@ -301,7 +301,7 @@ Pass SimplifyForFeatureExtraction() {
*/
Sequential PassListForPerStoreFeature() {
return Sequential({
tir::transform::RemoveWeightLayoutRewriteBlock(),
tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ true),
tir::transform::SimplifyForFeatureExtraction(),
tir::transform::LowerCrossThreadReduction(),
tir::transform::LowerInitBlock(),
3 changes: 2 additions & 1 deletion src/relay/backend/te_compiler_cache.cc
@@ -526,7 +526,8 @@ class ScheduleBuilder : public ExprVisitor {
record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false);
IRModule mod = sch->mod();
ICHECK_EQ(mod->functions.size(), 1);
mod = tir::transform::RemoveWeightLayoutRewriteBlock()(std::move(mod));
mod = tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ false)(
std::move(mod));
prim_func = Downcast<PrimFunc>(mod->Lookup("main"));
// Need to copy attrs from relay function over to prim func. Most notably the structural
// hash.
59 changes: 47 additions & 12 deletions src/tir/transforms/remove_weight_layout_rewrite_block.cc
@@ -34,13 +34,15 @@ namespace tir {

class RemoveLayoutRewriteBlock : public StmtMutator {
public:
static std::tuple<PrimFunc, Map<Buffer, Buffer>, std::unordered_map<const VarNode*, IndexMap>>
static std::tuple<PrimFunc, Map<Buffer, Buffer>, std::unordered_map<const VarNode*, IndexMap>,
std::unordered_map<const VarNode*, Array<PrimExpr>>>
Rewrite(PrimFunc f) {
RemoveLayoutRewriteBlock rewriter;

PrimFuncNode* n = f.CopyOnWrite();
n->body = rewriter(std::move(n->body));
return std::make_tuple(f, rewriter.buf_map_, rewriter.buffer_var_to_index_map_);
return std::make_tuple(f, rewriter.buf_map_, rewriter.buffer_var_to_index_map_,
rewriter.buffer_var_to_rewritten_shape_);
}

private:
@@ -95,6 +97,8 @@ class RemoveLayoutRewriteBlock : public StmtMutator {
}
buffer_var_to_index_map_[load->buffer->data.get()] = IndexMap(load_indices, store->indices);

buffer_var_to_rewritten_shape_[load->buffer->data.get()] = store->buffer->shape;

return Stmt(n);
}

@@ -106,6 +110,8 @@ class RemoveLayoutRewriteBlock : public StmtMutator {
/*! \brief Maps a buffer load to an index map associated with the load / store
in a layout rewrite block. */
std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_;
/*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */
std::unordered_map<const VarNode*, Array<PrimExpr>> buffer_var_to_rewritten_shape_;
};

// After RemoveLayoutRewriteBlock, the body of a compute update block references a
@@ -139,9 +145,15 @@ using BufferVarMap = std::unordered_map<const tir::VarNode*, const tir::VarNode*;

class AllocateConstRewrite : public StmtExprMutator {
public:
AllocateConstRewrite(const BufferVarMap& buffer_var_map,
const std::unordered_map<const VarNode*, IndexMap>& buffer_var_to_index_map)
: buffer_var_map_(buffer_var_map), buffer_var_to_index_map_(buffer_var_to_index_map) {}
AllocateConstRewrite(
const BufferVarMap& buffer_var_map,
const std::unordered_map<const VarNode*, IndexMap>& buffer_var_to_index_map,
const std::unordered_map<const VarNode*, Array<PrimExpr>>& buffer_var_to_rewritten_shape,
bool skip_ndarray_rewrite)
: buffer_var_map_(buffer_var_map),
buffer_var_to_index_map_(buffer_var_to_index_map),
buffer_var_to_rewritten_shape_(buffer_var_to_rewritten_shape),
skip_ndarray_rewrite_(skip_ndarray_rewrite) {}

private:
Stmt VisitStmt_(const BlockNode* op) final {
@@ -163,8 +175,10 @@ class AllocateConstRewrite : public StmtExprMutator {
Stmt VisitStmt_(const AllocateConstNode* alloc) final {
if (auto it = buffer_var_to_index_map_.find(alloc->buffer_var.get());
it != buffer_var_to_index_map_.end()) {
ICHECK(buffer_var_to_rewritten_shape_.count(alloc->buffer_var.get()));
auto new_body = StmtMutator::VisitStmt(alloc->body);
auto rewritten_ndarray = it->second->MapNDArray(alloc->data.value());
auto rewritten_ndarray = RewriteNDArray(
alloc->data.value(), it->second, buffer_var_to_rewritten_shape_[alloc->buffer_var.get()]);
Array<PrimExpr> rewritten_extents;
for (auto s : rewritten_ndarray.Shape()) {
rewritten_extents.push_back(PrimExpr(static_cast<int>(s)));
@@ -187,13 +201,32 @@ class AllocateConstRewrite : public StmtExprMutator {
return ExprMutator::VisitExpr_(op);
}

runtime::NDArray RewriteNDArray(runtime::NDArray src, const IndexMap& index_map,
const Array<PrimExpr>& dst_shape) {
if (skip_ndarray_rewrite_) {
// Only the shape of the destination array needs to be correct.
std::vector<int64_t> dst_shape_int;
for (auto s : dst_shape) {
ICHECK(s->IsInstance<IntImmNode>());
dst_shape_int.push_back(s.as<IntImmNode>()->value);
}
return src.CreateView(dst_shape_int, src.DataType());
} else {
return index_map->MapNDArray(src);
}
}

/*! \brief Maps a buffer store to a load in a layout rewrite block */
BufferVarMap buffer_var_map_;
/*! \brief Maps a buffer load to an index map associated with the load / store
in a layout rewrite block. */
std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_;
/*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */
std::unordered_map<const VarNode*, Array<PrimExpr>> buffer_var_to_rewritten_shape_;
/*! \brief Maps load buffer variables to newly created buffers */
std::unordered_map<const VarNode*, Buffer> new_load_buf_;
/*! \brief Whether or not to skip rewriting of NDArray contents */
bool skip_ndarray_rewrite_;
};

class CollectAllocateConstBufferVars : public StmtVisitor {
@@ -208,11 +241,12 @@ class CollectAllocateConstBufferVars : public StmtVisitor {

class WeightLayoutRewriteBlockRemover : public StmtMutator {
public:
static PrimFunc Remove(PrimFunc f) {
static PrimFunc Remove(PrimFunc f, bool skip_ndarray_rewrite) {
CollectAllocateConstBufferVars collector;
collector(f->body);

auto [f_, buf_map, buffer_var_to_index_map] = RemoveLayoutRewriteBlock().Rewrite(f);
auto [f_, buf_map, buffer_var_to_index_map, buffer_var_to_rewritten_shape] =
RemoveLayoutRewriteBlock().Rewrite(f);

BufferVarMap buffer_var_map;
for (const auto& [load_buf, store_buf] : buf_map) {
@@ -224,7 +258,8 @@ class WeightLayoutRewriteBlockRemover : public StmtMutator {

PrimFuncNode* n = f_.CopyOnWrite();

AllocateConstRewrite rewriter(buffer_var_map, buffer_var_to_index_map);
AllocateConstRewrite rewriter(buffer_var_map, buffer_var_to_index_map,
buffer_var_to_rewritten_shape, skip_ndarray_rewrite);
n->body = rewriter(std::move(n->body));

Map<tir::Var, Buffer> buffer_map;
@@ -243,9 +278,9 @@ class WeightLayoutRewriteBlockRemover : public StmtMutator {

namespace transform {

Pass RemoveWeightLayoutRewriteBlock() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
return WeightLayoutRewriteBlockRemover::Remove(std::move(f));
Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite) {
auto pass_func = [skip_ndarray_rewrite](PrimFunc f, IRModule m, PassContext ctx) {
return WeightLayoutRewriteBlockRemover::Remove(std::move(f), skip_ndarray_rewrite);
};
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveWeightLayoutRewriteBlock", {});
}
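To make the skip path concrete, here is a hedged Python sketch of the difference between the exact rewrite (IndexMap's MapNDArray, exposed in Python as `IndexMap.map_ndarray`, assuming that binding is available) and the shape-only result that `RewriteNDArray` produces when `skip_ndarray_rewrite` is set. The array and index map are made up for illustration:

```python
import numpy as np
import tvm
from tvm.tir import IndexMap

src = tvm.nd.array(np.arange(16, dtype="float32").reshape(2, 8))
index_map = IndexMap.from_func(lambda i, j: [j, i])  # toy layout rewrite (transpose)

# Exact path (skip_ndarray_rewrite = false): the data is physically relaid out.
exact = index_map.map_ndarray(src)  # shape (8, 2), transposed contents

# Skip path (skip_ndarray_rewrite = true): only the shape matches the rewritten
# buffer; the bytes are the source data reinterpreted in the new shape, which is
# what NDArray::CreateView does in the C++ code above.
dst_shape = [int(s) for s in exact.shape]
skipped = tvm.nd.array(src.numpy().reshape(dst_shape))

assert skipped.shape == exact.shape
# skipped and exact generally hold different values; only the shape agrees.
```

This is why the skip path is only safe where the constant's values are never inspected, e.g. while benchmarking candidates or extracting features during tuning, and why the Relay compilation path above keeps the exact rewrite.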
43 changes: 22 additions & 21 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -539,32 +539,33 @@ def test_rewrite_layout_link_params():
executor = relay.backend.Executor("graph", {"link-params": link_params})
mod = mod.with_attr("executor", executor)

with tempfile.TemporaryDirectory() as work_dir:
database = ms.relay_integration.tune_relay(
mod=mod,
target=target,
params=params,
work_dir=work_dir,
max_trials_global=4,
strategy="replay-trace",
)
for strategy in ["replay-trace", "evolutionary"]:
with tempfile.TemporaryDirectory() as work_dir:
database = ms.relay_integration.tune_relay(
mod=mod,
target=target,
params=params,
work_dir=work_dir,
max_trials_global=4,
strategy=strategy,
)

lib = ms.relay_integration.compile_relay(
database=database,
mod=mod,
target=target,
params=params,
)
lib = ms.relay_integration.compile_relay(
database=database,
mod=mod,
target=target,
params=params,
)

dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

runtime.set_input("data", data_np)
runtime.run()
runtime.set_input("data", data_np)
runtime.run()

out = runtime.get_output(0).numpy()
out = runtime.get_output(0).numpy()

np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":