From e0a5a1aff25bc16b42956e810166dd6d0d4a7162 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 12 Oct 2022 18:07:06 +0900
Subject: [PATCH 1/3] Allow skipping exact NDArray rewrite in RemoveWeightLayoutRewriteBlock

---
 include/tvm/tir/transform.h                   |  2 +-
 src/meta_schedule/arg_info.cc                 |  3 +-
 src/relay/backend/te_compiler_cache.cc        |  3 +-
 .../remove_weight_layout_rewrite_block.cc     | 56 +++++++++++++++----
 4 files changed, 49 insertions(+), 15 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 6aa1aca69970..8f8b10eda74b 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -674,7 +674,7 @@ TVM_DLL Pass InjectPTXAsyncCopy();
  * \brief Remove the weight layout rewrite block
  * \return The pass.
  */
-TVM_DLL Pass RemoveWeightLayoutRewriteBlock();
+TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite=false);
 
 /*!
  * \brief Add the explicit local stage for the shared memory access on GPU.
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 84d861cb59c3..4663fd90762a 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -103,7 +103,8 @@ Array<ArgInfo> ArgInfo::FromPrimFunc(const tir::PrimFunc& func) {
 
 Array<ArgInfo> ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) {
   if (remove_preproc) {
-    IRModule new_mod = tir::transform::RemoveWeightLayoutRewriteBlock()(mod);
+    IRModule new_mod =
+        tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ true)(mod);
     return ArgInfo::FromPrimFunc(FindEntryFunc(new_mod));
   }
   return ArgInfo::FromPrimFunc(FindEntryFunc(mod));
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index a1a4bedfb8b0..9a0a2bef9a47 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -526,7 +526,8 @@ class ScheduleBuilder : public ExprVisitor {
         record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false);
         IRModule mod = sch->mod();
         ICHECK_EQ(mod->functions.size(), 1);
-        mod = tir::transform::RemoveWeightLayoutRewriteBlock()(std::move(mod));
+        mod = tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ false)(
+            std::move(mod));
         prim_func = Downcast<tir::PrimFunc>(mod->Lookup("main"));
         // Need to copy attrs from relay function over to prim func. Most notably the structural
         // hash.
diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc
index 86f6700f2289..b6585786746e 100644
--- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc
+++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc
@@ -34,13 +34,15 @@ namespace tir {
 
 class RemoveLayoutRewriteBlock : public StmtMutator {
  public:
-  static std::tuple<PrimFunc, Map<Buffer, Buffer>, std::unordered_map<const VarNode*, IndexMap>>
+  static std::tuple<PrimFunc, Map<Buffer, Buffer>, std::unordered_map<const VarNode*, IndexMap>,
+                    std::unordered_map<const VarNode*, Array<PrimExpr>>>
   Rewrite(PrimFunc f) {
     RemoveLayoutRewriteBlock rewriter;
 
     PrimFuncNode* n = f.CopyOnWrite();
     n->body = rewriter(std::move(n->body));
 
-    return std::make_tuple(f, rewriter.buf_map_, rewriter.buffer_var_to_index_map_);
+    return std::make_tuple(f, rewriter.buf_map_, rewriter.buffer_var_to_index_map_,
+                           rewriter.buffer_var_to_rewritten_shape_);
   }
 
  private:
@@ -95,6 +97,8 @@ class RemoveLayoutRewriteBlock : public StmtMutator {
     }
     buffer_var_to_index_map_[load->buffer->data.get()] = IndexMap(load_indices, store->indices);
 
+    buffer_var_to_rewritten_shape_[load->buffer->data.get()] = store->buffer->shape;
+
     return Stmt(n);
   }
 
@@ -106,6 +110,7 @@ class RemoveLayoutRewriteBlock : public StmtMutator {
   /*! \brief Maps a buffer load to an index map associated with the load / store in a layout
    rewrite block. */
   std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_;
+  std::unordered_map<const VarNode*, Array<PrimExpr>> buffer_var_to_rewritten_shape_;
 };
 
 // After RemoveLayoutRewriteBlock, the body of a compute update block references a
@@ -139,9 +144,15 @@ using BufferVarMap = std::unordered_map<const VarNode*, const VarNode*>;
 
 class AllocateConstRewrite : public StmtExprMutator {
  public:
-  AllocateConstRewrite(const BufferVarMap& buffer_var_map,
-                       const std::unordered_map<const VarNode*, IndexMap>& buffer_var_to_index_map)
-      : buffer_var_map_(buffer_var_map), buffer_var_to_index_map_(buffer_var_to_index_map) {}
+  AllocateConstRewrite(
+      const BufferVarMap& buffer_var_map,
+      const std::unordered_map<const VarNode*, IndexMap>& buffer_var_to_index_map,
+      const std::unordered_map<const VarNode*, Array<PrimExpr>>& buffer_var_to_rewritten_shape,
+        bool skip_ndarray_rewrite)
+      : buffer_var_map_(buffer_var_map),
+        buffer_var_to_index_map_(buffer_var_to_index_map),
+        buffer_var_to_rewritten_shape_(buffer_var_to_rewritten_shape),
+        skip_ndarray_rewrite_(skip_ndarray_rewrite) {}
 
  private:
   Stmt VisitStmt_(const BlockNode* op) final {
@@ -163,8 +174,10 @@ class AllocateConstRewrite : public StmtExprMutator {
   Stmt VisitStmt_(const AllocateConstNode* alloc) final {
     if (auto it = buffer_var_to_index_map_.find(alloc->buffer_var.get());
         it != buffer_var_to_index_map_.end()) {
+      ICHECK(buffer_var_to_rewritten_shape_.count(alloc->buffer_var.get()));
       auto new_body = StmtMutator::VisitStmt(alloc->body);
-      auto rewritten_ndarray = it->second->MapNDArray(alloc->data.value());
+      auto rewritten_ndarray = RewriteNDArray(
+          alloc->data.value(), it->second, buffer_var_to_rewritten_shape_[alloc->buffer_var.get()]);
       Array<PrimExpr> rewritten_extents;
       for (auto s : rewritten_ndarray.Shape()) {
         rewritten_extents.push_back(PrimExpr(static_cast<int>(s)));
@@ -187,13 +200,29 @@ class AllocateConstRewrite : public StmtExprMutator {
     return ExprMutator::VisitExpr_(op);
   }
 
+  runtime::NDArray RewriteNDArray(runtime::NDArray src, const IndexMap& index_map,
+                                  const Array<PrimExpr>& dst_shape) {
+    if (skip_ndarray_rewrite_) {
+      std::vector<int64_t> dst_shape_int;
+      for (auto s : dst_shape) {
+        ICHECK(s->IsInstance<IntImmNode>());
+        dst_shape_int.push_back(s.as<IntImmNode>()->value);
+      }
+      return src.CreateView(dst_shape_int, src.DataType());
+    } else {
+      return index_map->MapNDArray(src);
+    }
+  }
+
   /*! \brief Maps a buffer store to a load in a layout rewrite block */
   BufferVarMap buffer_var_map_;
   /*! \brief Maps a buffer load to an index map associated with the load / store in a layout
    rewrite block. */
   std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_;
+  std::unordered_map<const VarNode*, Array<PrimExpr>> buffer_var_to_rewritten_shape_;
   /*! \brief Maps load buffer variables to newly created buffers */
   std::unordered_map<const VarNode*, Buffer> new_load_buf_;
+  bool skip_ndarray_rewrite_;
 };
 
 class CollectAllocateConstBufferVars : public StmtVisitor {
@@ -208,11 +237,12 @@ class CollectAllocateConstBufferVars : public StmtVisitor {
 
 class WeightLayoutRewriteBlockRemover : public StmtMutator {
  public:
-  static PrimFunc Remove(PrimFunc f) {
+  static PrimFunc Remove(PrimFunc f, bool skip_ndarray_rewrite) {
     CollectAllocateConstBufferVars collector;
     collector(f->body);
 
-    auto [f_, buf_map, buffer_var_to_index_map] = RemoveLayoutRewriteBlock().Rewrite(f);
+    auto [f_, buf_map, buffer_var_to_index_map, buffer_var_to_rewritten_shape] =
+        RemoveLayoutRewriteBlock().Rewrite(f);
 
     BufferVarMap buffer_var_map;
     for (const auto& [load_buf, store_buf] : buf_map) {
@@ -224,7 +254,9 @@ class WeightLayoutRewriteBlockRemover : public StmtMutator {
 
     PrimFuncNode* n = f_.CopyOnWrite();
 
-    AllocateConstRewrite rewriter(buffer_var_map, buffer_var_to_index_map);
+    AllocateConstRewrite rewriter(buffer_var_map, buffer_var_to_index_map,
+                                  buffer_var_to_rewritten_shape,
+                                  skip_ndarray_rewrite);
     n->body = rewriter(std::move(n->body));
 
     Map<Var, Buffer> buffer_map;
@@ -243,9 +275,9 @@ class WeightLayoutRewriteBlockRemover : public StmtMutator {
 
 namespace transform {
 
-Pass RemoveWeightLayoutRewriteBlock() {
-  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    return WeightLayoutRewriteBlockRemover::Remove(std::move(f));
+Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite) {
+  auto pass_func = [skip_ndarray_rewrite](PrimFunc f, IRModule m, PassContext ctx) {
+    return WeightLayoutRewriteBlockRemover::Remove(std::move(f), skip_ndarray_rewrite);
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.RemoveWeightLayoutRewriteBlock", {});
 }

From 7ca0ad07510999fb7a25e51ac1a2b169deb58667 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 12 Oct 2022 19:31:33 +0900
Subject: [PATCH 2/3] add doc

---
 include/tvm/tir/transform.h                    | 10 +++++++++-
 .../tvm/meta_schedule/builder/local_builder.py |  2 +-
 python/tvm/tir/transform/transform.py          | 16 ++++++++++++++--
 .../feature_extractor/per_store_feature.cc     |  2 +-
 .../remove_weight_layout_rewrite_block.cc      |  9 ++++++---
 5 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 8f8b10eda74b..e31919fbd223 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -672,9 +672,17 @@ TVM_DLL Pass InjectPTXAsyncCopy();
 
 /*!
  * \brief Remove the weight layout rewrite block
+ * \param skip_ndarray_rewrite If True, exact rewrite of NDArray, according to the given index map,
+ * will be skipped. Only the shape of the NDArray is transformed correctly, and the content of
+ * the destination array will be filled with random values.
+ *
+ * When this pass is called many times during MetaSchedule tuning, the raw data of NDArray,
+ * before and after rewrite, does not matter. Since NDArray layout rewrite, using IndexMap's
+ * MapNDArray, is currently slow, skipping the exact rewrite is sometimes necessary.
+ *
  * \return The pass.
  */
-TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite=false);
+TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_ndarray_rewrite = false);
 
 /*!
  * \brief Add the explicit local stage for the shared memory access on GPU.
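The trade-off documented above can be illustrated with a small, self-contained sketch. This is not part of the patch: plain NumPy stands in for tvm.runtime.NDArray, and the index map (i, j) -> (j, i) is an arbitrary, hypothetical layout rewrite, so every name below is illustrative only.

import numpy as np

# A 2x3 constant whose layout is rewritten by the hypothetical index map (i, j) -> (j, i).
src = np.arange(6, dtype="float32").reshape(2, 3)

# skip_ndarray_rewrite=False: move every element to the slot chosen by the index map
# (exact, but O(n) work per constant, which adds up when done repeatedly during tuning).
exact = np.empty((3, 2), dtype="float32")
for i in range(2):
    for j in range(3):
        exact[j, i] = src[i, j]

# skip_ndarray_rewrite=True: keep the existing bytes and only give them the destination
# shape, as CreateView does in the pass; good enough when only benchmarking the kernel.
shape_only = src.reshape(3, 2)

assert exact.shape == shape_only.shape        # the shape is always correct
assert not np.array_equal(exact, shape_only)  # the contents generally are not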
diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py
index 6e282d8cb62d..3ddca032ef76 100644
--- a/python/tvm/meta_schedule/builder/local_builder.py
+++ b/python/tvm/meta_schedule/builder/local_builder.py
@@ -257,7 +257,7 @@ def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, NDA
     from tvm.tir.transform import RemoveWeightLayoutRewriteBlock
 
     # pylint: enable=import-outside-toplevel
-    mod = RemoveWeightLayoutRewriteBlock()(mod)
+    mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod)
     return tvm_build(mod, target=target)
 
 
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index d95d15c0dfbe..7b3a81acc525 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -964,14 +964,26 @@ def InjectPTXAsyncCopy():
     return _ffi_api.InjectPTXAsyncCopy()  # type: ignore
 
 
-def RemoveWeightLayoutRewriteBlock():
+def RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=False):
     """Remove weight layout rewrite block before benchmarking during tuning stage.
+
+    Parameters
+    ----------
+    skip_ndarray_rewrite : bool
+        If True, exact rewrite of NDArray, according to the given index map, will be skipped.
+        Only the shape of the NDArray is transformed correctly, and the content of the destination
+        array will be filled with random values.
+
+        When this pass is called many times during MetaSchedule tuning, the raw data of NDArray,
+        before and after rewrite, does not matter. Since NDArray layout rewrite, using IndexMap's
+        MapNDArray, is currently slow, skipping the exact rewrite is sometimes necessary.
+
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.RemoveWeightLayoutRewriteBlock()  # type: ignore
+    return _ffi_api.RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite)  # type: ignore
 
 
 def ManifestSharedMemoryLocalStage():
diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc
index 422f21abe17a..f0459785f352 100644
--- a/src/meta_schedule/feature_extractor/per_store_feature.cc
+++ b/src/meta_schedule/feature_extractor/per_store_feature.cc
@@ -301,7 +301,7 @@ Pass SimplifyForFeatureExtraction() {
  */
 Sequential PassListForPerStoreFeature() {
   return Sequential({
-      tir::transform::RemoveWeightLayoutRewriteBlock(),
+      tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ true),
       tir::transform::SimplifyForFeatureExtraction(),
      tir::transform::LowerCrossThreadReduction(),
       tir::transform::LowerInitBlock(),
diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc
index b6585786746e..05b636f11403 100644
--- a/src/tir/transforms/remove_weight_layout_rewrite_block.cc
+++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc
@@ -110,6 +110,7 @@ class RemoveLayoutRewriteBlock : public StmtMutator {
   /*! \brief Maps a buffer load to an index map associated with the load / store in a layout
    rewrite block. */
   std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_;
+  /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */
   std::unordered_map<const VarNode*, Array<PrimExpr>> buffer_var_to_rewritten_shape_;
 };
 
@@ -148,7 +149,7 @@ class AllocateConstRewrite : public StmtExprMutator {
       const BufferVarMap& buffer_var_map,
       const std::unordered_map<const VarNode*, IndexMap>& buffer_var_to_index_map,
       const std::unordered_map<const VarNode*, Array<PrimExpr>>& buffer_var_to_rewritten_shape,
-        bool skip_ndarray_rewrite)
+      bool skip_ndarray_rewrite)
       : buffer_var_map_(buffer_var_map),
         buffer_var_to_index_map_(buffer_var_to_index_map),
         buffer_var_to_rewritten_shape_(buffer_var_to_rewritten_shape),
         skip_ndarray_rewrite_(skip_ndarray_rewrite) {}
@@ -203,6 +204,7 @@ class AllocateConstRewrite : public StmtExprMutator {
   runtime::NDArray RewriteNDArray(runtime::NDArray src, const IndexMap& index_map,
                                   const Array<PrimExpr>& dst_shape) {
     if (skip_ndarray_rewrite_) {
+      // Only the shape of the destination array needs to be correct.
       std::vector<int64_t> dst_shape_int;
       for (auto s : dst_shape) {
         ICHECK(s->IsInstance<IntImmNode>());
@@ -219,9 +221,11 @@ class AllocateConstRewrite : public StmtExprMutator {
   /*! \brief Maps a buffer load to an index map associated with the load / store in a layout
    rewrite block. */
   std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_;
+  /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */
   std::unordered_map<const VarNode*, Array<PrimExpr>> buffer_var_to_rewritten_shape_;
   /*! \brief Maps load buffer variables to newly created buffers */
   std::unordered_map<const VarNode*, Buffer> new_load_buf_;
+  /*! \brief Whether or not to skip rewriting of NDArray contents */
   bool skip_ndarray_rewrite_;
 };
 
@@ -255,8 +259,7 @@ class WeightLayoutRewriteBlockRemover : public StmtMutator {
     PrimFuncNode* n = f_.CopyOnWrite();
 
     AllocateConstRewrite rewriter(buffer_var_map, buffer_var_to_index_map,
-                                  buffer_var_to_rewritten_shape,
-                                  skip_ndarray_rewrite);
+                                  buffer_var_to_rewritten_shape, skip_ndarray_rewrite);
     n->body = rewriter(std::move(n->body));
 
     Map<Var, Buffer> buffer_map;

From 02da79ee3160c50354f2041204392ed270d6a384 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 12 Oct 2022 19:39:16 +0900
Subject: [PATCH 3/3] add test

---
 .../test_meta_schedule_relay_integration.py | 43 ++++++++++---------
 1 file changed, 22 insertions(+), 21 deletions(-)

diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 4047f44ac365..d5c81bcc56ba 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -539,32 +539,33 @@ def test_rewrite_layout_link_params():
     executor = relay.backend.Executor("graph", {"link-params": link_params})
     mod = mod.with_attr("executor", executor)
 
-    with tempfile.TemporaryDirectory() as work_dir:
-        database = ms.relay_integration.tune_relay(
-            mod=mod,
-            target=target,
-            params=params,
-            work_dir=work_dir,
-            max_trials_global=4,
-            strategy="replay-trace",
-        )
+    for strategy in ["replay-trace", "evolutionary"]:
+        with tempfile.TemporaryDirectory() as work_dir:
+            database = ms.relay_integration.tune_relay(
+                mod=mod,
+                target=target,
+                params=params,
+                work_dir=work_dir,
+                max_trials_global=4,
+                strategy=strategy,
+            )
 
-        lib = ms.relay_integration.compile_relay(
-            database=database,
-            mod=mod,
-            target=target,
-            params=params,
-        )
+            lib = ms.relay_integration.compile_relay(
+                database=database,
+                mod=mod,
+                target=target,
+                params=params,
+            )
 
-        dev = tvm.device(target, 0)
-        runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+            dev = tvm.device(target, 0)
+            runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
 
-        runtime.set_input("data", data_np)
-        runtime.run()
+            runtime.set_input("data", data_np)
+            runtime.run()
 
-        out = runtime.get_output(0).numpy()
+            out = runtime.get_output(0).numpy()
 
-        np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)
+            np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)
 
 
 if __name__ == "__main__":
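Taken together, the call sites touched in this series use the new flag as in the sketch below. This is illustrative only: sched_mod stands for a scheduled IRModule that still contains a weight layout rewrite block and is not defined here.

from tvm.tir.transform import RemoveWeightLayoutRewriteBlock

# During tuning (local_builder.py, per_store_feature.cc, arg_info.cc) only the
# rewritten shapes matter, so the slow exact NDArray rewrite is skipped.
tuning_mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(sched_mod)

# On the final compilation path (te_compiler_cache.cc), e.g. with link-params,
# the constant contents must be rewritten exactly via IndexMap::MapNDArray.
final_mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=False)(sched_mod)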