From 9f55bb83e8999161278d1e370358116e932fc37a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 14 Nov 2022 18:55:31 +0900 Subject: [PATCH 1/9] Update RewriteLayout to support schedules with cache read --- src/meta_schedule/postproc/rewrite_layout.cc | 138 ++++++++++++------- 1 file changed, 88 insertions(+), 50 deletions(-) diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index 3aed6680e30d..0e4e1fe9f8c0 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include "../utils.h" @@ -29,19 +30,11 @@ namespace tir { */ class BufferReadPosCollector : public StmtExprVisitor { public: - explicit BufferReadPosCollector(const Array& buffers) { - for (const Buffer& buf : buffers) { - buffers_.insert(buf.get()); - } - } + explicit BufferReadPosCollector(const Buffer& buffer) : buffer_(buffer.get()) {} - const std::unordered_map>& GetBufferLocations() const { - return buffer_locs_; - } + const std::pair& GetBufferLocation() const { return buffer_loc_; } - const std::unordered_map>& GetBufferIndexMap() const { - return buffer_index_maps_; - } + const Optional GetBufferIndexMap() const { return buffer_index_map_; } private: void VisitStmt_(const ForNode* op) final { @@ -61,7 +54,7 @@ class BufferReadPosCollector : public StmtExprVisitor { CHECK(cur_realize_.defined()) << "BufferLoad occurred outside of any block"; const Buffer& buffer = op->buffer; - if (buffers_.count(buffer.get())) { + if (buffer_ == buffer.get()) { Map subst_map; for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) { const Var& var = cur_realize_->block->iter_vars[i]->var; @@ -72,14 +65,14 @@ class BufferReadPosCollector : public StmtExprVisitor { for (const PrimExpr& e : op->indices) { subst_indices.push_back(Substitute(e, subst_map)); } - buffer_index_maps_[buffer.get()] = SuggestIndexMap(/*buffer=*/buffer, // - /*indices=*/subst_indices, // - /*loops=*/loop_stack_, // - /*predicate=*/cur_realize_->predicate, // - /*analyzer=*/&analyzer_); + buffer_index_map_ = SuggestIndexMap(/*buffer=*/buffer, // + /*indices=*/subst_indices, // + /*loops=*/loop_stack_, // + /*predicate=*/cur_realize_->predicate, // + /*analyzer=*/&analyzer_); int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer); ICHECK(buffer_index != -1); - buffer_locs_[buffer.get()] = std::make_pair(cur_realize_->block, buffer_index); + buffer_loc_ = std::make_pair(cur_realize_->block, buffer_index); } } @@ -94,11 +87,11 @@ class BufferReadPosCollector : public StmtExprVisitor { private: /*! \brief All interested buffer. */ - std::unordered_set buffers_; + const BufferNode* buffer_; /*! \brief The result mapping from buffer to its inner-most block and read index. */ - std::unordered_map> buffer_locs_; + std::pair buffer_loc_; /*! \brief The result mapping from buffer to its IndexMap. */ - std::unordered_map> buffer_index_maps_; + Optional buffer_index_map_; /*! \brief Loop stack for calculating IndexMap. 
*/ Array loop_stack_; @@ -143,8 +136,54 @@ Array CollectLayoutFreeBuffers(const PrimFuncNode* func) { return layout_free_buffers; } +std::optional> GetSuggestedIndexMap( + Buffer buffer, const PrimFuncNode* prim_func) { + BufferReadPosCollector collector(buffer); + collector(prim_func->body); + + const auto& index_map = collector.GetBufferIndexMap(); + + if (!index_map.defined() || !index_map) { + return std::nullopt; + } + + const auto& [anchor_block, buffer_index] = collector.GetBufferLocation(); + + return std::make_tuple(anchor_block, buffer_index, index_map.value()); +} + +std::vector GetCacheReadChain(const Buffer& buf, const PrimFuncNode* prim_func) { + class BufferReadChainCollector : public StmtVisitor { + public: + explicit BufferReadChainCollector(const Buffer& buffer) : cur_buffer_(buffer.get()) {} + + void VisitStmt_(const BlockNode* op) final { + if (!op->init && op->reads.size() == 1 && op->writes.size() == 1 && + op->reads[0]->buffer.get() == cur_buffer_) { + cache_read_chain.push_back(op->name_hint); + cur_buffer_ = op->writes[0]->buffer.get(); + } + StmtVisitor::VisitStmt_(op); + } + + std::vector cache_read_chain; + + private: + const BufferNode* cur_buffer_; + }; + + BufferReadChainCollector collector(buf); + collector(prim_func->body); + return collector.cache_read_chain; +} + bool RewriteLayout(const Schedule& sch) { std::vector> results; + auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int buffer_index) { + BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global"); + sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true()); + }; + for (const auto& [g_var, base_func] : sch->mod()->functions) { const String& func_name = g_var->name_hint; const auto* prim_func = base_func.as(); @@ -153,36 +192,35 @@ bool RewriteLayout(const Schedule& sch) { continue; } - Array layout_free_buffers = CollectLayoutFreeBuffers(prim_func); - - // Collect Buffer read positions - BufferReadPosCollector collector(layout_free_buffers); - collector(prim_func->body); - const auto& locations = collector.GetBufferLocations(); - const auto& index_maps = collector.GetBufferIndexMap(); - // Check all buffers are collected - if (locations.size() != layout_free_buffers.size() || - index_maps.size() != layout_free_buffers.size()) { - return false; - } - - for (const auto& kv : locations) { - const Buffer& buffer = GetRef(kv.first); - const Block& block = kv.second.first; - int buffer_index = kv.second.second; - - // Get IndexMap - const Optional index_map = index_maps.at(buffer.get()); - if (!index_map.defined()) { - continue; + for (auto buffer : CollectLayoutFreeBuffers(prim_func)) { + const auto cache_read_chain = GetCacheReadChain(buffer, prim_func); + if (cache_read_chain.empty()) { + auto tup_opt = GetSuggestedIndexMap(buffer, prim_func); + if (tup_opt == std::nullopt) continue; + + auto [anchor_block, buffer_index, index_map] = *tup_opt; + auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint, func_name); + add_layout_rewrite_block(anchor_block_rv, buffer_index); + sch->TransformLayout(anchor_block_rv, buffer_index, BufferIndexType::kRead, index_map, + NullOpt); + } else { + Block cache_read_block = sch->Get(sch->GetBlock(cache_read_chain.back(), func_name)); + ICHECK_EQ(cache_read_block->writes.size(), 1); + auto tup_opt = GetSuggestedIndexMap(cache_read_block->writes[0]->buffer, prim_func); + if (tup_opt == std::nullopt) continue; + + auto [anchor_block, buffer_index, index_map] = *tup_opt; + 
sch->TransformLayout(sch->GetBlock(anchor_block->name_hint, func_name), buffer_index, + BufferIndexType::kRead, index_map, NullOpt); + + for (int i = static_cast(cache_read_chain.size()) - 1; i >= 0; --i) { + BlockRV cache_read_block_rv = sch->GetBlock(cache_read_chain[i], func_name); + if (i == 0) { + add_layout_rewrite_block(cache_read_block_rv, 0); + } + sch->TransformLayout(cache_read_block_rv, 0, BufferIndexType::kRead, index_map, NullOpt); + } } - - // Apply schedule - BlockRV block_rv = sch->GetBlock(block->name_hint, func_name); - BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global"); - sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value(), - NullOpt); - sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true()); } } return true; From 1a55c21e4cde1b188727f0e95aecbbf749320982 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Nov 2022 08:12:00 +0900 Subject: [PATCH 2/9] add test --- ...t_meta_schedule_postproc_rewrite_layout.py | 153 +++++++++++++++++- 1 file changed, 152 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py index 91a51c8e9033..bff8cf0d9f95 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py @@ -204,5 +204,156 @@ def test_layout_rewrite(): tvm.ir.assert_structural_equal(sch.mod["main"], rewritten_tir_matmul) +# fmt: off +@tvm.script.ir_module +class Conv2dVTCMCacheRead: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]): + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + pad_temp_global_vtcm = T.alloc_buffer([1, 58, 58, 64], dtype="float32", scope="global.vtcm") + p1_global_vtcm = T.alloc_buffer([3, 3, 64, 64], dtype="float32", scope="global.vtcm") + for i0_0_i1_0_i2_0_i3_0_fused in T.parallel(64, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): + for ax0_ax1_ax2_ax3_fused in T.serial(9216): + with T.block("pad_temp_global.vtcm"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + ax0_ax1_ax2_ax3_fused // 1024) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax0_ax1_ax2_ax3_fused % 1024 // 64) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(p0[v0, v1 - 1, v2 - 1, v3]) + T.writes(pad_temp_global_vtcm[v0, v1, v2, v3]) + pad_temp_global_vtcm[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 57 and 1 <= v2 and v2 < 57, p0[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32") + for ax0_ax1_ax2_ax3_fused in T.serial(18432): + with T.block("p1_global.vtcm"): + v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 6144) + v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 6144 // 2048) + v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 2048 // 32) + v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) + T.reads(p1[v0, v1, v2, v3]) + T.writes(p1_global_vtcm[v0, v1, v2, v3]) + p1_global_vtcm[v0, v1, v2, v3] = p1[v0, v1, v2, v3] + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 1, 1, 1): + for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 7, 1, 16, 1, 1, 14): + for 
i3_3_fused_init in T.vectorized(2): + with T.block("conv2d_nhwc_init"): + nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2_init + i1_3_init) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2_init * 14 + i2_3_init) + ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2_init * 2 + i3_3_fused_init) + T.reads() + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) + for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 7, 1, 16, 3, 3, 32, 1, 1, 14): + for i3_3_fused in T.vectorized(2): + with T.block("conv2d_nhwc_update"): + nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2 + i1_3) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2 * 14 + i2_3) + ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2 * 2 + i3_3_fused) + ry = T.axis.reduce(3, i4_0 * 3 + i4_1) + rx = T.axis.reduce(3, i5_0 * 3 + i5_1) + rc = T.axis.reduce(64, i6_0 * 32 + i6_1) + T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc], p1_global_vtcm[ry, rx, rc, ff]) + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc] * p1_global_vtcm[ry, rx, rc, ff] + for ax0, ax1, ax2 in T.grid(1, 7, 14): + for ax3_fused in T.vectorized(32): + with T.block("conv2d_nhwc_global"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax2) + v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax3_fused) + T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) + T.writes(conv2d_nhwc[v0, v1, v2, v3]) + conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] + + +@tvm.script.ir_module +class Conv2dVTCMCacheReadRewritten: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]): + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + pad_temp_global_vtcm = T.alloc_buffer([1, 58, 58, 64], dtype="float32", scope="global.vtcm") + p1_global_vtcm = T.alloc_buffer([2, 2, 16, 3, 3, 32, 2], dtype="float32", scope="global.vtcm") + p1_global = T.alloc_buffer([2, 2, 16, 3, 3, 32, 2], dtype="float32") + for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64): + with T.block("p1_global"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(p1[v0, v1, v2, v3]) + T.writes(p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) + T.block_attr({"meta_schedule.layout_rewrite_preproc":True}) + p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3] + for i0_0_i1_0_i2_0_i3_0_fused in T.parallel(64, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): + for ax0_ax1_ax2_ax3_fused in T.serial(9216): + with T.block("pad_temp_global.vtcm"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + 
ax0_ax1_ax2_ax3_fused // 1024) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax0_ax1_ax2_ax3_fused % 1024 // 64) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(p0[v0, v1 - 1, v2 - 1, v3]) + T.writes(pad_temp_global_vtcm[v0, v1, v2, v3]) + pad_temp_global_vtcm[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 57 and 1 <= v2 and v2 < 57, p0[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32") + for ax0_ax1_ax2_ax3_fused in T.serial(18432): + with T.block("p1_global.vtcm"): + v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 6144) + v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 6144 // 2048) + v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 2048 // 32) + v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) + T.reads(p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) + T.writes(p1_global_vtcm[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) + p1_global_vtcm[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 1, 1, 1): + for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 7, 1, 16, 1, 1, 14): + for i3_3_fused_init in T.vectorized(2): + with T.block("conv2d_nhwc_init"): + nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2_init + i1_3_init) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2_init * 14 + i2_3_init) + ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2_init * 2 + i3_3_fused_init) + T.reads() + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) + for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 7, 1, 16, 3, 3, 32, 1, 1, 14): + for i3_3_fused in T.vectorized(2): + with T.block("conv2d_nhwc_update"): + nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2 + i1_3) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2 * 14 + i2_3) + ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2 * 2 + i3_3_fused) + ry = T.axis.reduce(3, i4_0 * 3 + i4_1) + rx = T.axis.reduce(3, i5_0 * 3 + i5_1) + rc = T.axis.reduce(64, i6_0 * 32 + i6_1) + T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc], p1_global_vtcm[ff // 32, rc // 32, ff % 32 // 2, ry, rx, rc % 32, ff % 2]) + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc] * p1_global_vtcm[ff // 32, rc // 32, ff % 32 // 2, ry, rx, rc % 32, ff % 2] + for ax0, ax1, ax2 in T.grid(1, 7, 14): + for ax3_fused in T.vectorized(32): + with T.block("conv2d_nhwc_global"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax2) + v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax3_fused) + T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) + T.writes(conv2d_nhwc[v0, v1, v2, v3]) + conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] + + +# fmt: 
on + +def test_layout_rewrite_cache_read(): + target = Target("llvm") + ctx = _create_context(Conv2dVTCMCacheRead, target) + sch = tvm.tir.Schedule(Conv2dVTCMCacheRead, debug_mask="all") + sch.enter_postproc() + assert ctx.space_generator.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, Conv2dVTCMCacheReadRewritten) + + if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_layout_rewrite_cache_read() From b5f98f6a75cfa55bfed2a6d0637aee2ae5cfa0b9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Nov 2022 08:22:03 +0900 Subject: [PATCH 3/9] add test for multiple cache read cast --- ...t_meta_schedule_postproc_rewrite_layout.py | 93 ++++++++++++++++++- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py index bff8cf0d9f95..c43be0bf5d26 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py @@ -343,6 +343,86 @@ def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] +# This is a contrived example to demonstrate layout rewrite propagating over multiple cache reads. +@tvm.script.ir_module +class Conv2dVTCMCacheReadMultipleRewritten: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]): + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + pad_temp_global_vtcm = T.alloc_buffer([1, 58, 58, 64], dtype="float32", scope="global.vtcm") + p1_global_vtcm = T.alloc_buffer([2, 2, 16, 3, 3, 32, 2], dtype="float32", scope="global.vtcm") + p1_global_vtcm2 = T.alloc_buffer([2, 2, 16, 3, 3, 32, 2], dtype="float32", scope="global.vtcm2") + p1_global = T.alloc_buffer([2, 2, 16, 3, 3, 32, 2], dtype="float32") + for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64): + with T.block("p1_global"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(p1[v0, v1, v2, v3]) + T.writes(p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) + T.block_attr({"meta_schedule.layout_rewrite_preproc":True}) + p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3] + for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64): + with T.block("p1_global.vtcm2"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) + T.writes(p1_global_vtcm2[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) + p1_global_vtcm2[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] + for i0_0_i1_0_i2_0_i3_0_fused in T.parallel(64, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): + for ax0_ax1_ax2_ax3_fused in T.serial(9216): + with T.block("pad_temp_global.vtcm"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + ax0_ax1_ax2_ax3_fused // 1024) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax0_ax1_ax2_ax3_fused % 1024 // 64) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(p0[v0, v1 - 1, v2 - 1, v3]) + T.writes(pad_temp_global_vtcm[v0, 
v1, v2, v3]) + pad_temp_global_vtcm[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 57 and 1 <= v2 and v2 < 57, p0[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32") + for ax0_ax1_ax2_ax3_fused in T.serial(18432): + with T.block("p1_global.vtcm"): + v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 6144) + v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 6144 // 2048) + v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 2048 // 32) + v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) + T.reads(p1_global_vtcm2[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) + T.writes(p1_global_vtcm[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) + p1_global_vtcm[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global_vtcm2[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 1, 1, 1): + for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 7, 1, 16, 1, 1, 14): + for i3_3_fused_init in T.vectorized(2): + with T.block("conv2d_nhwc_init"): + nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2_init + i1_3_init) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2_init * 14 + i2_3_init) + ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2_init * 2 + i3_3_fused_init) + T.reads() + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) + for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 7, 1, 16, 3, 3, 32, 1, 1, 14): + for i3_3_fused in T.vectorized(2): + with T.block("conv2d_nhwc_update"): + nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2 + i1_3) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2 * 14 + i2_3) + ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2 * 2 + i3_3_fused) + ry = T.axis.reduce(3, i4_0 * 3 + i4_1) + rx = T.axis.reduce(3, i5_0 * 3 + i5_1) + rc = T.axis.reduce(64, i6_0 * 32 + i6_1) + T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc], p1_global_vtcm[ff // 32, rc // 32, ff % 32 // 2, ry, rx, rc % 32, ff % 2]) + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc] * p1_global_vtcm[ff // 32, rc // 32, ff % 32 // 2, ry, rx, rc % 32, ff % 2] + for ax0, ax1, ax2 in T.grid(1, 7, 14): + for ax3_fused in T.vectorized(32): + with T.block("conv2d_nhwc_global"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax2) + v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax3_fused) + T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) + T.writes(conv2d_nhwc[v0, v1, v2, v3]) + conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] + # fmt: on def test_layout_rewrite_cache_read(): @@ -354,6 +434,15 @@ def test_layout_rewrite_cache_read(): tvm.ir.assert_structural_equal(sch.mod, Conv2dVTCMCacheReadRewritten) +def test_layout_rewrite_cache_read_multiple(): + target = Target("llvm") + ctx = 
_create_context(Conv2dVTCMCacheRead, target) + sch = tvm.tir.Schedule(Conv2dVTCMCacheRead, debug_mask="all") + sch.cache_read(sch.get_block("p1_global.vtcm"), 0, "global.vtcm2") + sch.enter_postproc() + assert ctx.space_generator.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, Conv2dVTCMCacheReadMultipleRewritten) + + if __name__ == "__main__": - # tvm.testing.main() - test_layout_rewrite_cache_read() + tvm.testing.main() From 0b71c1a438b9f2018e01b835a000b828452c1040 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Nov 2022 08:37:30 +0900 Subject: [PATCH 4/9] do not use vtcm --- .../measure_callback/update_cost_model.cc | 1 + ...t_meta_schedule_postproc_rewrite_layout.py | 425 ++++++++++-------- 2 files changed, 231 insertions(+), 195 deletions(-) diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index 8a8a43658409..ff9f7b953caf 100644 --- a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -42,6 +42,7 @@ class UpdateCostModelNode : public MeasureCallbackNode { pruned_candidate.reserve(n); pruned_runner_result.reserve(n); for (int i = 0; i < n; i++) { + ICHECK(runner_results[i]->run_secs); if (!builder_results[i]->error_msg.defined() && Sum(runner_results[i]->run_secs.value()) > 0) { pruned_candidate.push_back(measure_candidates[i]); diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py index c43be0bf5d26..1e0606df016f 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py @@ -206,242 +206,277 @@ def test_layout_rewrite(): # fmt: off @tvm.script.ir_module -class Conv2dVTCMCacheRead: +class Conv2dCacheRead: @T.prim_func def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]): T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32") conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32") - pad_temp_global_vtcm = T.alloc_buffer([1, 58, 58, 64], dtype="float32", scope="global.vtcm") - p1_global_vtcm = T.alloc_buffer([3, 3, 64, 64], dtype="float32", scope="global.vtcm") - for i0_0_i1_0_i2_0_i3_0_fused in T.parallel(64, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): - for ax0_ax1_ax2_ax3_fused in T.serial(9216): - with T.block("pad_temp_global.vtcm"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + ax0_ax1_ax2_ax3_fused // 1024) - v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax0_ax1_ax2_ax3_fused % 1024 // 64) - v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) - T.reads(p0[v0, v1 - 1, v2 - 1, v3]) - T.writes(pad_temp_global_vtcm[v0, v1, v2, v3]) - pad_temp_global_vtcm[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 57 and 1 <= v2 and v2 < 57, p0[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32") - for ax0_ax1_ax2_ax3_fused in T.serial(18432): - with T.block("p1_global.vtcm"): - v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 6144) - v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 6144 // 2048) - v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 2048 // 32) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + 
ax0_ax1_ax2_ax3_fused % 32) - T.reads(p1[v0, v1, v2, v3]) - T.writes(p1_global_vtcm[v0, v1, v2, v3]) - p1_global_vtcm[v0, v1, v2, v3] = p1[v0, v1, v2, v3] - for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 1, 1, 1): - for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 7, 1, 16, 1, 1, 14): - for i3_3_fused_init in T.vectorized(2): - with T.block("conv2d_nhwc_init"): - nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) - yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2_init + i1_3_init) - xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2_init * 14 + i2_3_init) - ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2_init * 2 + i3_3_fused_init) - T.reads() - T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) - conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) - for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 7, 1, 16, 3, 3, 32, 1, 1, 14): - for i3_3_fused in T.vectorized(2): - with T.block("conv2d_nhwc_update"): - nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) - yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2 + i1_3) - xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2 * 14 + i2_3) - ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2 * 2 + i3_3_fused) - ry = T.axis.reduce(3, i4_0 * 3 + i4_1) - rx = T.axis.reduce(3, i5_0 * 3 + i5_1) - rc = T.axis.reduce(64, i6_0 * 32 + i6_1) - T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc], p1_global_vtcm[ry, rx, rc, ff]) - T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) - conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc] * p1_global_vtcm[ry, rx, rc, ff] - for ax0, ax1, ax2 in T.grid(1, 7, 14): - for ax3_fused in T.vectorized(32): - with T.block("conv2d_nhwc_global"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + ax1) - v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax2) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax3_fused) - T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) - T.writes(conv2d_nhwc[v0, v1, v2, v3]) - conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] + pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + p1_global = T.alloc_buffer([3, 3, 64, 64], dtype="float32") + for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): + for ax0, ax1, ax2 in T.grid(1, 30, 30): + for ax3_fused in T.vectorized(64): + with T.block("pad_temp"): + i0 = T.axis.spatial(1, ax0) + i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1) + i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2) + i3 = T.axis.spatial(64, ax3_fused) + T.reads(p0[i0, i1 - 1, i2 - 1, i3]) + T.writes(pad_temp[i0, i1, i2, i3]) + pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") + for i3_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused in T.serial(57600): + with T.block("pad_temp_global"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920) + v2 = T.axis.spatial(58, 
i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(pad_temp[v0, v1, v2, v3]) + T.writes(pad_temp_global[v0, v1, v2, v3]) + pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused in T.serial(2304): + with T.block("p1_global"): + v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768) + v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256) + v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4) + v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4) + T.reads(p1[v0, v1, v2, v3]) + T.writes(p1_global[v0, v1, v2, v3]) + p1_global[v0, v1, v2, v3] = p1[v0, v1, v2, v3] + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1): + for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): + for i3_3_fused_init in T.vectorized(2): + with T.block("conv2d_nhwc_init"): + nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) + xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) + ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) + T.reads() + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) + for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): + for i3_3_fused in T.vectorized(2): + with T.block("conv2d_nhwc_update"): + nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) + xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) + ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused) + ry = T.axis.reduce(3, i4_0 * 3 + i4_1) + rx = T.axis.reduce(3, i5_0 * 3 + i5_1) + rc = T.axis.reduce(64, i6_0 * 32 + i6_1) + T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ry, rx, rc, ff]) + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ry, rx, rc, ff] + for ax0, ax1, ax2 in T.grid(1, 4, 14): + for ax3_fused in T.vectorized(4): + with T.block("conv2d_nhwc_global"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2) + v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused) + T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) + T.writes(conv2d_nhwc[v0, v1, v2, v3]) + conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @tvm.script.ir_module -class Conv2dVTCMCacheReadRewritten: +class Conv2dCacheReadRewritten: @T.prim_func def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]): T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32") conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32") - pad_temp_global_vtcm = T.alloc_buffer([1, 58, 58, 64], dtype="float32", scope="global.vtcm") - p1_global_vtcm = T.alloc_buffer([2, 2, 16, 3, 3, 32, 2], 
dtype="float32", scope="global.vtcm") - p1_global = T.alloc_buffer([2, 2, 16, 3, 3, 32, 2], dtype="float32") + pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + p1_global = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32") + p1_global_1 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32") for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64): with T.block("p1_global"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(p1[v0, v1, v2, v3]) - T.writes(p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) + T.writes(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) T.block_attr({"meta_schedule.layout_rewrite_preproc":True}) - p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3] - for i0_0_i1_0_i2_0_i3_0_fused in T.parallel(64, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): - for ax0_ax1_ax2_ax3_fused in T.serial(9216): - with T.block("pad_temp_global.vtcm"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + ax0_ax1_ax2_ax3_fused // 1024) - v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax0_ax1_ax2_ax3_fused % 1024 // 64) - v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) - T.reads(p0[v0, v1 - 1, v2 - 1, v3]) - T.writes(pad_temp_global_vtcm[v0, v1, v2, v3]) - pad_temp_global_vtcm[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 57 and 1 <= v2 and v2 < 57, p0[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32") - for ax0_ax1_ax2_ax3_fused in T.serial(18432): - with T.block("p1_global.vtcm"): - v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 6144) - v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 6144 // 2048) - v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 2048 // 32) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) - T.reads(p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) - T.writes(p1_global_vtcm[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) - p1_global_vtcm[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] - for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 1, 1, 1): - for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 7, 1, 16, 1, 1, 14): - for i3_3_fused_init in T.vectorized(2): - with T.block("conv2d_nhwc_init"): - nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) - yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2_init + i1_3_init) - xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2_init * 14 + i2_3_init) - ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2_init * 2 + i3_3_fused_init) - T.reads() - T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) - conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) - for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 7, 1, 16, 3, 3, 32, 1, 1, 14): - for i3_3_fused in T.vectorized(2): - with T.block("conv2d_nhwc_update"): - nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) - yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2 + i1_3) - xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2 * 14 + i2_3) - ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2 * 2 + i3_3_fused) - ry = T.axis.reduce(3, 
i4_0 * 3 + i4_1) - rx = T.axis.reduce(3, i5_0 * 3 + i5_1) - rc = T.axis.reduce(64, i6_0 * 32 + i6_1) - T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc], p1_global_vtcm[ff // 32, rc // 32, ff % 32 // 2, ry, rx, rc % 32, ff % 2]) - T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) - conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc] * p1_global_vtcm[ff // 32, rc // 32, ff % 32 // 2, ry, rx, rc % 32, ff % 2] - for ax0, ax1, ax2 in T.grid(1, 7, 14): - for ax3_fused in T.vectorized(32): - with T.block("conv2d_nhwc_global"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + ax1) - v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax2) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax3_fused) - T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) - T.writes(conv2d_nhwc[v0, v1, v2, v3]) - conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] - - -# This is a contrived example to demonstrate layout rewrite propagating over multiple cache reads. + p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3] + for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): + for ax0, ax1, ax2 in T.grid(1, 30, 30): + for ax3_fused in T.vectorized(64): + with T.block("pad_temp"): + i0 = T.axis.spatial(1, ax0) + i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1) + i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2) + i3 = T.axis.spatial(64, ax3_fused) + T.reads(p0[i0, i1 - 1, i2 - 1, i3]) + T.writes(pad_temp[i0, i1, i2, i3]) + pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") + for i3_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused in T.serial(57600): + with T.block("pad_temp_global"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(pad_temp[v0, v1, v2, v3]) + T.writes(pad_temp_global[v0, v1, v2, v3]) + pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused in T.serial(2304): + with T.block("p1_global"): + v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768) + v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256) + v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4) + v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4) + T.reads(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + T.writes(p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1): + for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): + for i3_3_fused_init in T.vectorized(2): + with T.block("conv2d_nhwc_init"): + nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) + xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) 
+ ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) + T.reads() + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) + for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): + for i3_3_fused in T.vectorized(2): + with T.block("conv2d_nhwc_update"): + nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) + xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) + ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused) + ry = T.axis.reduce(3, i4_0 * 3 + i4_1) + rx = T.axis.reduce(3, i5_0 * 3 + i5_1) + rc = T.axis.reduce(64, i6_0 * 32 + i6_1) + T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2]) + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2] + for ax0, ax1, ax2 in T.grid(1, 4, 14): + for ax3_fused in T.vectorized(4): + with T.block("conv2d_nhwc_global"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2) + v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused) + T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) + T.writes(conv2d_nhwc[v0, v1, v2, v3]) + conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] + + @tvm.script.ir_module -class Conv2dVTCMCacheReadMultipleRewritten: +class Conv2dCacheReadMultipleRewritten: @T.prim_func def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]): T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32") conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32") - pad_temp_global_vtcm = T.alloc_buffer([1, 58, 58, 64], dtype="float32", scope="global.vtcm") - p1_global_vtcm = T.alloc_buffer([2, 2, 16, 3, 3, 32, 2], dtype="float32", scope="global.vtcm") - p1_global_vtcm2 = T.alloc_buffer([2, 2, 16, 3, 3, 32, 2], dtype="float32", scope="global.vtcm2") - p1_global = T.alloc_buffer([2, 2, 16, 3, 3, 32, 2], dtype="float32") + pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + p1_global = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32") + p1_global2 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32", scope="global2") + p1_global_1 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32") for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64): with T.block("p1_global"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(p1[v0, v1, v2, v3]) - T.writes(p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) + T.writes(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) T.block_attr({"meta_schedule.layout_rewrite_preproc":True}) - p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3] + p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3] for ax0, 
ax1, ax2, ax3 in T.grid(3, 3, 64, 64): - with T.block("p1_global.vtcm2"): + with T.block("p1_global2"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) - T.writes(p1_global_vtcm2[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) - p1_global_vtcm2[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] - for i0_0_i1_0_i2_0_i3_0_fused in T.parallel(64, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): - for ax0_ax1_ax2_ax3_fused in T.serial(9216): - with T.block("pad_temp_global.vtcm"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + ax0_ax1_ax2_ax3_fused // 1024) - v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax0_ax1_ax2_ax3_fused % 1024 // 64) - v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) - T.reads(p0[v0, v1 - 1, v2 - 1, v3]) - T.writes(pad_temp_global_vtcm[v0, v1, v2, v3]) - pad_temp_global_vtcm[v0, v1, v2, v3] = T.if_then_else(1 <= v1 and v1 < 57 and 1 <= v2 and v2 < 57, p0[v0, v1 - 1, v2 - 1, v3], T.float32(0), dtype="float32") - for ax0_ax1_ax2_ax3_fused in T.serial(18432): - with T.block("p1_global.vtcm"): - v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 6144) - v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 6144 // 2048) - v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 2048 // 32) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) - T.reads(p1_global_vtcm2[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) - T.writes(p1_global_vtcm[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2]) - p1_global_vtcm[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global_vtcm2[v3 // 32, v2 // 32, v3 % 32 // 2, v0, v1, v2 % 32, v3 % 2] - for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 1, 1, 1): - for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 7, 1, 16, 1, 1, 14): - for i3_3_fused_init in T.vectorized(2): - with T.block("conv2d_nhwc_init"): - nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) - yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2_init + i1_3_init) - xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2_init * 14 + i2_3_init) - ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2_init * 2 + i3_3_fused_init) - T.reads() - T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) - conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) - for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 7, 1, 16, 3, 3, 32, 1, 1, 14): - for i3_3_fused in T.vectorized(2): - with T.block("conv2d_nhwc_update"): - nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) - yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + i1_1 * 7 + i1_2 + i1_3) - xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + i2_1 * 14 + i2_2 * 14 + i2_3) - ff = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + i3_1 * 32 + i3_2 * 2 + i3_3_fused) - ry = T.axis.reduce(3, i4_0 * 3 + i4_1) - rx = T.axis.reduce(3, i5_0 * 3 + i5_1) - rc = T.axis.reduce(64, i6_0 * 32 + i6_1) - T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc], p1_global_vtcm[ff // 32, rc // 32, ff % 32 // 2, ry, rx, rc % 32, ff % 2]) - T.writes(conv2d_nhwc_global[nn, yy, 
xx, ff]) - T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) - conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global_vtcm[nn, yy + ry, xx + rx, rc] * p1_global_vtcm[ff // 32, rc // 32, ff % 32 // 2, ry, rx, rc % 32, ff % 2] - for ax0, ax1, ax2 in T.grid(1, 7, 14): - for ax3_fused in T.vectorized(32): - with T.block("conv2d_nhwc_global"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 8 * 7 + ax1) - v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 8 // 2 * 14 + ax2) - v3 = T.axis.spatial(64, i0_0_i1_0_i2_0_i3_0_fused % 2 * 32 + ax3_fused) - T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) - T.writes(conv2d_nhwc[v0, v1, v2, v3]) - conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] + T.reads(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + T.writes(p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] + for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): + for ax0, ax1, ax2 in T.grid(1, 30, 30): + for ax3_fused in T.vectorized(64): + with T.block("pad_temp"): + i0 = T.axis.spatial(1, ax0) + i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1) + i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2) + i3 = T.axis.spatial(64, ax3_fused) + T.reads(p0[i0, i1 - 1, i2 - 1, i3]) + T.writes(pad_temp[i0, i1, i2, i3]) + pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") + for i3_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused in T.serial(57600): + with T.block("pad_temp_global"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(pad_temp[v0, v1, v2, v3]) + T.writes(pad_temp_global[v0, v1, v2, v3]) + pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused in T.serial(2304): + with T.block("p1_global"): + v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768) + v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256) + v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4) + v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4) + T.reads(p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + T.writes(p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1): + for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): + for i3_3_fused_init in T.vectorized(2): + with T.block("conv2d_nhwc_init"): + nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) + xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) + ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) + T.reads() + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + 
conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) + for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): + for i3_3_fused in T.vectorized(2): + with T.block("conv2d_nhwc_update"): + nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) + xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) + ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused) + ry = T.axis.reduce(3, i4_0 * 3 + i4_1) + rx = T.axis.reduce(3, i5_0 * 3 + i5_1) + rc = T.axis.reduce(64, i6_0 * 32 + i6_1) + T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2]) + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2] + for ax0, ax1, ax2 in T.grid(1, 4, 14): + for ax3_fused in T.vectorized(4): + with T.block("conv2d_nhwc_global"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2) + v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused) + T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) + T.writes(conv2d_nhwc[v0, v1, v2, v3]) + conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] # fmt: on def test_layout_rewrite_cache_read(): target = Target("llvm") - ctx = _create_context(Conv2dVTCMCacheRead, target) - sch = tvm.tir.Schedule(Conv2dVTCMCacheRead, debug_mask="all") + ctx = _create_context(Conv2dCacheRead, target) + sch = tvm.tir.Schedule(Conv2dCacheRead, debug_mask="all") sch.enter_postproc() assert ctx.space_generator.postprocs[0].apply(sch) - tvm.ir.assert_structural_equal(sch.mod, Conv2dVTCMCacheReadRewritten) + tvm.ir.assert_structural_equal(sch.mod, Conv2dCacheReadRewritten) def test_layout_rewrite_cache_read_multiple(): target = Target("llvm") - ctx = _create_context(Conv2dVTCMCacheRead, target) - sch = tvm.tir.Schedule(Conv2dVTCMCacheRead, debug_mask="all") - sch.cache_read(sch.get_block("p1_global.vtcm"), 0, "global.vtcm2") + ctx = _create_context(Conv2dCacheRead, target) + sch = tvm.tir.Schedule(Conv2dCacheRead, debug_mask="all") + sch.cache_read(sch.get_block("p1_global"), 0, "global2") sch.enter_postproc() assert ctx.space_generator.postprocs[0].apply(sch) - tvm.ir.assert_structural_equal(sch.mod, Conv2dVTCMCacheReadMultipleRewritten) + tvm.ir.assert_structural_equal(sch.mod, Conv2dCacheReadMultipleRewritten) if __name__ == "__main__": From ad042429d0f0a28c0282fc6c335a69337dc02262 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Nov 2022 09:46:07 +0900 Subject: [PATCH 5/9] add comment --- src/meta_schedule/postproc/rewrite_layout.cc | 21 ++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index 0e4e1fe9f8c0..b4f496ae8627 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -26,7 +26,7 @@ namespace tir { /*! * \brief Collect the block and index where the buffer is read. 
- * \note The buffers are expected to be read by only one BufferLoad + * \note The buffer is expected to be read by only one BufferLoad */ class BufferReadPosCollector : public StmtExprVisitor { public: @@ -86,11 +86,11 @@ class BufferReadPosCollector : public StmtExprVisitor { } private: - /*! \brief All interested buffer. */ + /*! \brief The buffer of interest. */ const BufferNode* buffer_; - /*! \brief The result mapping from buffer to its inner-most block and read index. */ + /*! \brief The block that consumes the buffer and the corresponding read index. */ std::pair buffer_loc_; - /*! \brief The result mapping from buffer to its IndexMap. */ + /*! \brief The proposed IndexMap. */ Optional buffer_index_map_; /*! \brief Loop stack for calculating IndexMap. */ @@ -152,12 +152,14 @@ std::optional> GetSuggestedIndexMap( return std::make_tuple(anchor_block, buffer_index, index_map.value()); } +/*! \brief Get a chain of cache-read blocks, starting from the one consuming buf. */ std::vector GetCacheReadChain(const Buffer& buf, const PrimFuncNode* prim_func) { class BufferReadChainCollector : public StmtVisitor { public: explicit BufferReadChainCollector(const Buffer& buffer) : cur_buffer_(buffer.get()) {} void VisitStmt_(const BlockNode* op) final { + // Check if this block is doing cache_read or a similar operation that consumes cur_buffer_. if (!op->init && op->reads.size() == 1 && op->writes.size() == 1 && op->reads[0]->buffer.get() == cur_buffer_) { cache_read_chain.push_back(op->name_hint); @@ -195,6 +197,8 @@ bool RewriteLayout(const Schedule& sch) { for (auto buffer : CollectLayoutFreeBuffers(prim_func)) { const auto cache_read_chain = GetCacheReadChain(buffer, prim_func); if (cache_read_chain.empty()) { + // The common case, where the layout-free buffer is directly consumed by an anchor op such + // as conv2d or dense. auto tup_opt = GetSuggestedIndexMap(buffer, prim_func); if (tup_opt == std::nullopt) continue; @@ -204,18 +208,27 @@ bool RewriteLayout(const Schedule& sch) { sch->TransformLayout(anchor_block_rv, buffer_index, BufferIndexType::kRead, index_map, NullOpt); } else { + // When the layout-free buffer is consumed by cache_read, we need to find the index map + // for a cache-read buffer that is directly consumed by an anchor op. The last buffer + // in cache_read_chain corresponds to that buffer. Block cache_read_block = sch->Get(sch->GetBlock(cache_read_chain.back(), func_name)); ICHECK_EQ(cache_read_block->writes.size(), 1); auto tup_opt = GetSuggestedIndexMap(cache_read_block->writes[0]->buffer, prim_func); if (tup_opt == std::nullopt) continue; auto [anchor_block, buffer_index, index_map] = *tup_opt; + // Transform the layout of the last cache-read buffer. sch->TransformLayout(sch->GetBlock(anchor_block->name_hint, func_name), buffer_index, BufferIndexType::kRead, index_map, NullOpt); + // Propagate the layout transformation over cache_read_chain, starting from + // the next-to-last cache-read buffer. for (int i = static_cast(cache_read_chain.size()) - 1; i >= 0; --i) { BlockRV cache_read_block_rv = sch->GetBlock(cache_read_chain[i], func_name); if (i == 0) { + // Before the first cache_read that consumes the layout-free buffer, insert + // a layout-rewrite block. Another cache read buffer is added, and its layout is + // transformed by TransformLayout below. 
add_layout_rewrite_block(cache_read_block_rv, 0); } sch->TransformLayout(cache_read_block_rv, 0, BufferIndexType::kRead, index_map, NullOpt); From 29aa4ee19f3dc7e14e2f8c521b5dd8d4168b2a94 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Nov 2022 09:58:24 +0900 Subject: [PATCH 6/9] black --- .../unittest/test_meta_schedule_postproc_rewrite_layout.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py index 1e0606df016f..98c1f7368580 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py @@ -460,6 +460,7 @@ def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), # fmt: on + def test_layout_rewrite_cache_read(): target = Target("llvm") ctx = _create_context(Conv2dCacheRead, target) From e91f675e946475f2b98e6166eaf0e2036be6f2aa Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Nov 2022 11:05:26 +0900 Subject: [PATCH 7/9] Fix CreatePrimFunc for link-params=True case --- src/meta_schedule/postproc/rewrite_layout.cc | 2 +- src/te/operation/create_primfunc.cc | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc index b4f496ae8627..71ae43387112 100644 --- a/src/meta_schedule/postproc/rewrite_layout.cc +++ b/src/meta_schedule/postproc/rewrite_layout.cc @@ -227,7 +227,7 @@ bool RewriteLayout(const Schedule& sch) { BlockRV cache_read_block_rv = sch->GetBlock(cache_read_chain[i], func_name); if (i == 0) { // Before the first cache_read that consumes the layout-free buffer, insert - // a layout-rewrite block. Another cache read buffer is added, and its layout is + // a layout-rewrite block. Another cache-read buffer is added, and its layout is // transformed by TransformLayout below. 
add_layout_rewrite_block(cache_read_block_rv, 0); } diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 80da5a727926..0581ad60e8f4 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -110,13 +110,20 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { Block block = Downcast(StmtMutator::VisitStmt_(_block)); BlockNode* n = block.CopyOnWrite(); if (Optional ann = n->annotations.Get(topi_attr)) { + Array new_buffers; for (Buffer buffer : Downcast>(ann)) { auto it = buffer2index_.find(buffer); if (it != buffer2index_.end()) { layout_free_buffer_indices_.insert(it->second); + } else { + new_buffers.push_back(buffer); } } - n->annotations.erase(topi_attr); + if (new_buffers.empty()) { + n->annotations.erase(topi_attr); + } else { + n->annotations.Set(topi_attr, new_buffers); + } } for (const String& attr : this->blocklist) { auto it = n->annotations.find(attr); From 405a3d4799ef3cae66b13dea46fc93d32f4c03df Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Nov 2022 16:37:18 +0900 Subject: [PATCH 8/9] fix te create_primfunc test --- tests/python/unittest/test_te_create_primfunc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index b59880758e5d..7b8173d0b2d9 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -390,6 +390,7 @@ def expected_layout_attr( C[x, y] = C[x, y] + A[x, k] * B[y, k] for i0, i1 in T.grid(128, 128): with T.block("D"): + T.block_attr({"layout_free_placeholders": [C]}) x, y = T.axis.remap("SS", [i0, i1]) D[x, y] = C[x, y] + T.float32(1) From 30199c3ec6ce75cfe5ec0149c8c5363d71725feb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 15 Nov 2022 19:34:22 +0900 Subject: [PATCH 9/9] revert change in update_cost_model.cc --- src/meta_schedule/measure_callback/update_cost_model.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index ff9f7b953caf..8a8a43658409 100644 --- a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -42,7 +42,6 @@ class UpdateCostModelNode : public MeasureCallbackNode { pruned_candidate.reserve(n); pruned_runner_result.reserve(n); for (int i = 0; i < n; i++) { - ICHECK(runner_results[i]->run_secs); if (!builder_results[i]->error_msg.defined() && Sum(runner_results[i]->run_secs.value()) > 0) { pruned_candidate.push_back(measure_candidates[i]);
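---

Note (illustrative only, not part of the patch series): for anyone trying the change locally, the two tests above can be replayed as one standalone snippet. It assumes the Conv2dCacheRead workload and the _create_context helper defined earlier in test_meta_schedule_postproc_rewrite_layout.py, and it mirrors test_layout_rewrite_cache_read_multiple, printing the rewritten module instead of asserting structural equality.

    import tvm
    from tvm.target import Target

    target = Target("llvm")
    ctx = _create_context(Conv2dCacheRead, target)

    sch = tvm.tir.Schedule(Conv2dCacheRead, debug_mask="all")
    # Stage the cached weight through one more scope, as in the "multiple" test:
    # RewriteLayout now follows the whole chain of cache-read blocks feeding the
    # conv2d block when it inserts the layout-rewrite block and propagates the
    # transformed layout along the chain.
    sch.cache_read(sch.get_block("p1_global"), 0, "global2")

    sch.enter_postproc()
    assert ctx.space_generator.postprocs[0].apply(sch)
    print(sch.mod.script())  # inspect the rewritten module

The printed module should be structurally equal to Conv2dCacheReadMultipleRewritten from the test file above.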