diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc
index 3aed6680e30d..71ae43387112 100644
--- a/src/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -16,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <optional>
 #include
 #include "../utils.h"
@@ -25,23 +26,15 @@ namespace tir {
 /*!
  * \brief Collect the block and index where the buffer is read.
- * \note The buffers are expected to be read by only one BufferLoad
+ * \note The buffer is expected to be read by only one BufferLoad
  */
 class BufferReadPosCollector : public StmtExprVisitor {
  public:
-  explicit BufferReadPosCollector(const Array<Buffer>& buffers) {
-    for (const Buffer& buf : buffers) {
-      buffers_.insert(buf.get());
-    }
-  }
+  explicit BufferReadPosCollector(const Buffer& buffer) : buffer_(buffer.get()) {}
-  const std::unordered_map<const BufferNode*, std::pair<Block, int>>& GetBufferLocations() const {
-    return buffer_locs_;
-  }
+  const std::pair<Block, int>& GetBufferLocation() const { return buffer_loc_; }
-  const std::unordered_map<const BufferNode*, Optional<IndexMap>>& GetBufferIndexMap() const {
-    return buffer_index_maps_;
-  }
+  const Optional<IndexMap> GetBufferIndexMap() const { return buffer_index_map_; }
  private:
   void VisitStmt_(const ForNode* op) final {
@@ -61,7 +54,7 @@ class BufferReadPosCollector : public StmtExprVisitor {
     CHECK(cur_realize_.defined()) << "BufferLoad occurred outside of any block";
     const Buffer& buffer = op->buffer;
-    if (buffers_.count(buffer.get())) {
+    if (buffer_ == buffer.get()) {
       Map<Var, PrimExpr> subst_map;
       for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) {
         const Var& var = cur_realize_->block->iter_vars[i]->var;
@@ -72,14 +65,14 @@
       for (const PrimExpr& e : op->indices) {
         subst_indices.push_back(Substitute(e, subst_map));
       }
-      buffer_index_maps_[buffer.get()] = SuggestIndexMap(/*buffer=*/buffer,                    //
-                                                         /*indices=*/subst_indices,            //
-                                                         /*loops=*/loop_stack_,                //
-                                                         /*predicate=*/cur_realize_->predicate,  //
-                                                         /*analyzer=*/&analyzer_);
+      buffer_index_map_ = SuggestIndexMap(/*buffer=*/buffer,                     //
+                                          /*indices=*/subst_indices,             //
+                                          /*loops=*/loop_stack_,                 //
+                                          /*predicate=*/cur_realize_->predicate,  //
+                                          /*analyzer=*/&analyzer_);
       int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer);
       ICHECK(buffer_index != -1);
-      buffer_locs_[buffer.get()] = std::make_pair(cur_realize_->block, buffer_index);
+      buffer_loc_ = std::make_pair(cur_realize_->block, buffer_index);
     }
   }
@@ -93,12 +86,12 @@
  private:
-  /*! \brief All interested buffer. */
-  std::unordered_set<const BufferNode*> buffers_;
-  /*! \brief The result mapping from buffer to its inner-most block and read index. */
-  std::unordered_map<const BufferNode*, std::pair<Block, int>> buffer_locs_;
-  /*! \brief The result mapping from buffer to its IndexMap. */
-  std::unordered_map<const BufferNode*, Optional<IndexMap>> buffer_index_maps_;
+  /*! \brief The buffer of interest. */
+  const BufferNode* buffer_;
+  /*! \brief The block that consumes the buffer and the corresponding read index. */
+  std::pair<Block, int> buffer_loc_;
+  /*! \brief The proposed IndexMap. */
+  Optional<IndexMap> buffer_index_map_;
   /*! \brief Loop stack for calculating IndexMap. */
   Array<For> loop_stack_;
@@ -143,8 +136,56 @@ Array<Buffer> CollectLayoutFreeBuffers(const PrimFuncNode* func) {
   return layout_free_buffers;
 }
 
+std::optional<std::tuple<Block, int, IndexMap>> GetSuggestedIndexMap(
+    Buffer buffer, const PrimFuncNode* prim_func) {
+  BufferReadPosCollector collector(buffer);
+  collector(prim_func->body);
+
+  const auto& index_map = collector.GetBufferIndexMap();
+
+  if (!index_map.defined() || !index_map) {
+    return std::nullopt;
+  }
+
+  const auto& [anchor_block, buffer_index] = collector.GetBufferLocation();
+
+  return std::make_tuple(anchor_block, buffer_index, index_map.value());
+}
+
+/*! \brief Get a chain of cache-read blocks, starting from the one consuming buf. */
+std::vector<String> GetCacheReadChain(const Buffer& buf, const PrimFuncNode* prim_func) {
+  class BufferReadChainCollector : public StmtVisitor {
+   public:
+    explicit BufferReadChainCollector(const Buffer& buffer) : cur_buffer_(buffer.get()) {}
+
+    void VisitStmt_(const BlockNode* op) final {
+      // Check if this block is doing cache_read or a similar operation that consumes cur_buffer_.
+      if (!op->init && op->reads.size() == 1 && op->writes.size() == 1 &&
+          op->reads[0]->buffer.get() == cur_buffer_) {
+        cache_read_chain.push_back(op->name_hint);
+        cur_buffer_ = op->writes[0]->buffer.get();
+      }
+      StmtVisitor::VisitStmt_(op);
+    }
+
+    std::vector<String> cache_read_chain;
+
+   private:
+    const BufferNode* cur_buffer_;
+  };
+
+  BufferReadChainCollector collector(buf);
+  collector(prim_func->body);
+  return collector.cache_read_chain;
+}
+
 bool RewriteLayout(const Schedule& sch) {
   std::vector> results;
+  auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int buffer_index) {
+    BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global");
+    sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
+  };
+
   for (const auto& [g_var, base_func] : sch->mod()->functions) {
     const String& func_name = g_var->name_hint;
     const auto* prim_func = base_func.as<PrimFuncNode>();
@@ -153,36 +194,46 @@ bool RewriteLayout(const Schedule& sch) {
       continue;
     }
-    Array<Buffer> layout_free_buffers = CollectLayoutFreeBuffers(prim_func);
-
-    // Collect Buffer read positions
-    BufferReadPosCollector collector(layout_free_buffers);
-    collector(prim_func->body);
-    const auto& locations = collector.GetBufferLocations();
-    const auto& index_maps = collector.GetBufferIndexMap();
-    // Check all buffers are collected
-    if (locations.size() != layout_free_buffers.size() ||
-        index_maps.size() != layout_free_buffers.size()) {
-      return false;
-    }
-
-    for (const auto& kv : locations) {
-      const Buffer& buffer = GetRef<Buffer>(kv.first);
-      const Block& block = kv.second.first;
-      int buffer_index = kv.second.second;
-
-      // Get IndexMap
-      const Optional<IndexMap> index_map = index_maps.at(buffer.get());
-      if (!index_map.defined()) {
-        continue;
+    for (auto buffer : CollectLayoutFreeBuffers(prim_func)) {
+      const auto cache_read_chain = GetCacheReadChain(buffer, prim_func);
+      if (cache_read_chain.empty()) {
+        // The common case, where the layout-free buffer is directly consumed by an anchor op such
+        // as conv2d or dense.
+        auto tup_opt = GetSuggestedIndexMap(buffer, prim_func);
+        if (tup_opt == std::nullopt) continue;
+
+        auto [anchor_block, buffer_index, index_map] = *tup_opt;
+        auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint, func_name);
+        add_layout_rewrite_block(anchor_block_rv, buffer_index);
+        sch->TransformLayout(anchor_block_rv, buffer_index, BufferIndexType::kRead, index_map,
+                             NullOpt);
+      } else {
+        // When the layout-free buffer is consumed by cache_read, we need to find the index map
+        // for a cache-read buffer that is directly consumed by an anchor op. The last buffer
+        // in cache_read_chain corresponds to that buffer.
+        Block cache_read_block = sch->Get(sch->GetBlock(cache_read_chain.back(), func_name));
+        ICHECK_EQ(cache_read_block->writes.size(), 1);
+        auto tup_opt = GetSuggestedIndexMap(cache_read_block->writes[0]->buffer, prim_func);
+        if (tup_opt == std::nullopt) continue;
+
+        auto [anchor_block, buffer_index, index_map] = *tup_opt;
+        // Transform the layout of the last cache-read buffer.
+        sch->TransformLayout(sch->GetBlock(anchor_block->name_hint, func_name), buffer_index,
+                             BufferIndexType::kRead, index_map, NullOpt);
+
+        // Propagate the layout transformation over cache_read_chain, starting from
+        // the next-to-last cache-read buffer.
+        for (int i = static_cast<int>(cache_read_chain.size()) - 1; i >= 0; --i) {
+          BlockRV cache_read_block_rv = sch->GetBlock(cache_read_chain[i], func_name);
+          if (i == 0) {
+            // Before the first cache_read that consumes the layout-free buffer, insert
+            // a layout-rewrite block. Another cache-read buffer is added, and its layout is
+            // transformed by TransformLayout below.
+            add_layout_rewrite_block(cache_read_block_rv, 0);
+          }
+          sch->TransformLayout(cache_read_block_rv, 0, BufferIndexType::kRead, index_map, NullOpt);
+        }
       }
-
-      // Apply schedule
-      BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
-      BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
-      sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value(),
-                           NullOpt);
-      sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
     }
   }
   return true;
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 80da5a727926..0581ad60e8f4 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -110,13 +110,20 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
     Block block = Downcast<Block>(StmtMutator::VisitStmt_(_block));
     BlockNode* n = block.CopyOnWrite();
     if (Optional<ObjectRef> ann = n->annotations.Get(topi_attr)) {
+      Array<Buffer> new_buffers;
       for (Buffer buffer : Downcast<Array<Buffer>>(ann)) {
         auto it = buffer2index_.find(buffer);
         if (it != buffer2index_.end()) {
           layout_free_buffer_indices_.insert(it->second);
+        } else {
+          new_buffers.push_back(buffer);
         }
       }
-      n->annotations.erase(topi_attr);
+      if (new_buffers.empty()) {
+        n->annotations.erase(topi_attr);
+      } else {
+        n->annotations.Set(topi_attr, new_buffers);
+      }
     }
     for (const String& attr : this->blocklist) {
       auto it = n->annotations.find(attr);
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
index 91a51c8e9033..98c1f7368580 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
@@ -204,5 +204,281 @@ def test_layout_rewrite():
     tvm.ir.assert_structural_equal(sch.mod["main"], rewritten_tir_matmul)
+# fmt:
off +@tvm.script.ir_module +class Conv2dCacheRead: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]): + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + p1_global = T.alloc_buffer([3, 3, 64, 64], dtype="float32") + for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): + for ax0, ax1, ax2 in T.grid(1, 30, 30): + for ax3_fused in T.vectorized(64): + with T.block("pad_temp"): + i0 = T.axis.spatial(1, ax0) + i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1) + i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2) + i3 = T.axis.spatial(64, ax3_fused) + T.reads(p0[i0, i1 - 1, i2 - 1, i3]) + T.writes(pad_temp[i0, i1, i2, i3]) + pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") + for i3_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused in T.serial(57600): + with T.block("pad_temp_global"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(pad_temp[v0, v1, v2, v3]) + T.writes(pad_temp_global[v0, v1, v2, v3]) + pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused in T.serial(2304): + with T.block("p1_global"): + v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768) + v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256) + v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4) + v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4) + T.reads(p1[v0, v1, v2, v3]) + T.writes(p1_global[v0, v1, v2, v3]) + p1_global[v0, v1, v2, v3] = p1[v0, v1, v2, v3] + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1): + for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): + for i3_3_fused_init in T.vectorized(2): + with T.block("conv2d_nhwc_init"): + nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) + xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) + ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) + T.reads() + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) + for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): + for i3_3_fused in T.vectorized(2): + with T.block("conv2d_nhwc_update"): + nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) + xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) + ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused) + ry = T.axis.reduce(3, i4_0 * 3 + i4_1) + rx = T.axis.reduce(3, i5_0 * 3 + i5_1) + rc = T.axis.reduce(64, i6_0 * 32 + i6_1) + 
T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ry, rx, rc, ff]) + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ry, rx, rc, ff] + for ax0, ax1, ax2 in T.grid(1, 4, 14): + for ax3_fused in T.vectorized(4): + with T.block("conv2d_nhwc_global"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2) + v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused) + T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) + T.writes(conv2d_nhwc[v0, v1, v2, v3]) + conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] + + +@tvm.script.ir_module +class Conv2dCacheReadRewritten: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]): + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + p1_global = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32") + p1_global_1 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32") + for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64): + with T.block("p1_global"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(p1[v0, v1, v2, v3]) + T.writes(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + T.block_attr({"meta_schedule.layout_rewrite_preproc":True}) + p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3] + for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): + for ax0, ax1, ax2 in T.grid(1, 30, 30): + for ax3_fused in T.vectorized(64): + with T.block("pad_temp"): + i0 = T.axis.spatial(1, ax0) + i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1) + i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2) + i3 = T.axis.spatial(64, ax3_fused) + T.reads(p0[i0, i1 - 1, i2 - 1, i3]) + T.writes(pad_temp[i0, i1, i2, i3]) + pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") + for i3_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused in T.serial(57600): + with T.block("pad_temp_global"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(pad_temp[v0, v1, v2, v3]) + T.writes(pad_temp_global[v0, v1, v2, v3]) + pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused in T.serial(2304): + with T.block("p1_global"): + v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768) + v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256) + v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4) + v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4) + T.reads(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + T.writes(p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 
32, v3 % 2]) + p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1): + for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): + for i3_3_fused_init in T.vectorized(2): + with T.block("conv2d_nhwc_init"): + nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) + xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) + ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) + T.reads() + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) + for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): + for i3_3_fused in T.vectorized(2): + with T.block("conv2d_nhwc_update"): + nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) + xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) + ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused) + ry = T.axis.reduce(3, i4_0 * 3 + i4_1) + rx = T.axis.reduce(3, i5_0 * 3 + i5_1) + rc = T.axis.reduce(64, i6_0 * 32 + i6_1) + T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2]) + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2] + for ax0, ax1, ax2 in T.grid(1, 4, 14): + for ax3_fused in T.vectorized(4): + with T.block("conv2d_nhwc_global"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2) + v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused) + T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) + T.writes(conv2d_nhwc[v0, v1, v2, v3]) + conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] + + +@tvm.script.ir_module +class Conv2dCacheReadMultipleRewritten: + @T.prim_func + def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]): + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32") + pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32") + p1_global = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32") + p1_global2 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32", scope="global2") + p1_global_1 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32") + for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64): + with T.block("p1_global"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(p1[v0, v1, v2, v3]) + T.writes(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + T.block_attr({"meta_schedule.layout_rewrite_preproc":True}) + p1_global_1[v3 // 4, 
v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3] + for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64): + with T.block("p1_global2"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + T.writes(p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] + for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): + for ax0, ax1, ax2 in T.grid(1, 30, 30): + for ax3_fused in T.vectorized(64): + with T.block("pad_temp"): + i0 = T.axis.spatial(1, ax0) + i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1) + i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2) + i3 = T.axis.spatial(64, ax3_fused) + T.reads(p0[i0, i1 - 1, i2 - 1, i3]) + T.writes(pad_temp[i0, i1, i2, i3]) + pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32") + for i3_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused in T.serial(57600): + with T.block("pad_temp_global"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64) + v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64) + T.reads(pad_temp[v0, v1, v2, v3]) + T.writes(pad_temp_global[v0, v1, v2, v3]) + pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused in T.serial(2304): + with T.block("p1_global"): + v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768) + v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256) + v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4) + v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4) + T.reads(p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + T.writes(p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]) + p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1): + for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1): + for i3_3_fused_init in T.vectorized(2): + with T.block("conv2d_nhwc_init"): + nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init) + xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init) + ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init) + T.reads() + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0) + for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1): + for i3_3_fused in T.vectorized(2): + with T.block("conv2d_nhwc_update"): + nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3) + xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2) + ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused) + ry = T.axis.reduce(3, i4_0 
* 3 + i4_1) + rx = T.axis.reduce(3, i5_0 * 3 + i5_1) + rc = T.axis.reduce(64, i6_0 * 32 + i6_1) + T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2]) + T.writes(conv2d_nhwc_global[nn, yy, xx, ff]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2] + for ax0, ax1, ax2 in T.grid(1, 4, 14): + for ax3_fused in T.vectorized(4): + with T.block("conv2d_nhwc_global"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2) + v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused) + T.reads(conv2d_nhwc_global[v0, v1, v2, v3]) + T.writes(conv2d_nhwc[v0, v1, v2, v3]) + conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] + +# fmt: on + + +def test_layout_rewrite_cache_read(): + target = Target("llvm") + ctx = _create_context(Conv2dCacheRead, target) + sch = tvm.tir.Schedule(Conv2dCacheRead, debug_mask="all") + sch.enter_postproc() + assert ctx.space_generator.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, Conv2dCacheReadRewritten) + + +def test_layout_rewrite_cache_read_multiple(): + target = Target("llvm") + ctx = _create_context(Conv2dCacheRead, target) + sch = tvm.tir.Schedule(Conv2dCacheRead, debug_mask="all") + sch.cache_read(sch.get_block("p1_global"), 0, "global2") + sch.enter_postproc() + assert ctx.space_generator.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, Conv2dCacheReadMultipleRewritten) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index b59880758e5d..7b8173d0b2d9 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -390,6 +390,7 @@ def expected_layout_attr( C[x, y] = C[x, y] + A[x, k] * B[y, k] for i0, i1 in T.grid(128, 128): with T.block("D"): + T.block_attr({"layout_free_placeholders": [C]}) x, y = T.axis.remap("SS", [i0, i1]) D[x, y] = C[x, y] + T.float32(1)
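
Note (not part of the patch): a minimal sketch of the primitive sequence that the cache_read path of RewriteLayout boils down to, written against the public tir.Schedule API. The matmul workload, block names, and the index map below are assumptions made up for illustration, not the postprocessor's actual choices.

import tvm
from tvm import te, tir


def make_matmul():
    A = te.placeholder((128, 128), name="A")
    B = te.placeholder((128, 128), name="B")
    k = te.reduce_axis((0, 128), name="k")
    C = te.compute(
        (128, 128),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name="matmul",
    )
    return te.create_prim_func([A, B, C])


sch = tir.Schedule(make_matmul())
anchor = sch.get_block("matmul")

# A search-generated cache_read of the layout-free input B (read buffer index 1).
cached = sch.cache_read(anchor, 1, "global")

# Step 1: pick an index map for the cache-read buffer as it is read by the anchor block
# and rewrite that layout (GetSuggestedIndexMap + TransformLayout in the C++ code).
index_map = tir.IndexMap.from_func(lambda i, j: [j // 32, i, j % 32])  # assumed mapping
sch.transform_layout(anchor, ("read", 1), index_map)

# Step 2: walk the cache-read chain back towards the layout-free parameter. In front of
# the first cache_read, insert the annotated layout-rewrite block, then transform the
# layout of the buffer it consumes with the same index map.
rewrite_block = sch.cache_read(cached, 0, "global")
sch.annotate(rewrite_block, "meta_schedule.layout_rewrite_preproc", True)
sch.transform_layout(cached, ("read", 0), index_map)

print(sch.mod.script())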
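
Note (not part of the patch): a small sketch of the create_primfunc.cc behavior exercised by the last hunk. A "layout_free_placeholders" entry that is not a function input (the intermediate C here) is now preserved as a block attribute instead of being erased; the tensor names are illustrative.

import tvm
from tvm import te

A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
k = te.reduce_axis((0, 128), name="k")
C = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name="C")
D = te.compute(
    (128, 128),
    lambda i, j: C[i, j] + 1.0,
    name="D",
    attrs={"layout_free_placeholders": [C]},
)
func = te.create_prim_func([A, B, D])
# Block "D" keeps T.block_attr({"layout_free_placeholders": [C]}), while placeholders that
# are function parameters are still lowered to the "layout_free_buffers" func attribute.
print(func.script())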