From 570849f4a6b9cba921b52d3b8dbeb3192be97b85 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Thu, 29 Dec 2022 08:30:25 -0800
Subject: [PATCH 01/13] upd

---
 include/tvm/tir/schedule/schedule.h           |  26 +
 python/tvm/tir/schedule/schedule.py           |  52 ++
 src/tir/schedule/concrete_schedule.cc         |  24 +
 src/tir/schedule/concrete_schedule.h          |   4 +
 src/tir/schedule/primitive.h                  |  31 +
 .../schedule/primitive/cache_read_write.cc    | 702 ++++++++++++++++++
 src/tir/schedule/schedule.cc                  |   4 +
 src/tir/schedule/traced_schedule.cc           |  30 +
 src/tir/schedule/traced_schedule.h            |   4 +
 9 files changed, 877 insertions(+)

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 8b22c173a3d8..f322188f52c5 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -405,6 +405,32 @@ class ScheduleNode : public runtime::Object {
   virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                              const String& storage_scope,
                              const Array<BlockRV> consumer_blocks = {}) = 0;
+  /*!
+   * \brief Create a block that reads a buffer region into a read cache. It requires:
+   *        1) There is at most one block that writes the buffer in the scope.
+   *        2) The scope block has the stage-pipeline property.
+   * Compared to cache_read, the index mapping is performed at the producer rather than the
+   * consumer.
+   * \param block_rv The consumer block of the target buffer.
+   * \param read_buffer_index The index of the buffer in the block's read region.
+   * \param storage_scope The target storage scope.
+   * \param dim_order The user-defined dimension order of the allocated buffer.
+   * \return The cache stage block.
+   */
+  virtual BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                                   const String& storage_scope, Array<Integer> dim_order) = 0;
+  /*!
+   * \brief Create a block that writes a buffer region into a write cache. It requires:
+   *        1) There is only one block that writes the target buffer.
+   *        2) The scope block has the stage-pipeline property.
+   * Compared to cache_write, the index mapping is performed at the consumer rather than the
+   * producer.
+   * \param block_rv The producer block of the buffer.
+   * \param write_buffer_index The index of the buffer in the block's write region.
+   * \param storage_scope The target storage scope.
+   * \param dim_order The user-defined dimension order of the allocated buffer.
+   * \return The cache stage block.
+   */
+  virtual BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                                    const String& storage_scope, Array<Integer> dim_order) = 0;
   /*!
    * \brief Create 2 blocks that read&write a buffer region into a read/write cache.
    * It requires the target block both read & write the target buffer.
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 48850012cbb7..cec6b5f70508 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1198,6 +1198,52 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
             self, block, write_buffer_index, storage_scope, consumer_blocks
         )
 
+    @type_checked
+    def reverse_cache_read(
+        self, block: BlockRV, read_buffer_index: int, storage_scope: str, dim_order: List[int] = []
+    ) -> BlockRV:
+        """Create a block that reads a buffer region into a read cache.
+        The index mapping is performed at the producer rather than the consumer.
+
+        Parameters
+        ----------
+        block : BlockRV
+            The consumer block of the target buffer.
+        read_buffer_index: int
+            The index of the buffer in the block's read region.
+        storage_scope: str
+            The target storage scope.
+        dim_order: List[int]
+            The user-defined dimension order of the allocated buffer.
+            Numbers indicate the indices of block iter vars.
+
+        Returns
+        -------
+        cached_block : BlockRV
+            The block of the cache stage.
+        """
+        return _ffi_api.ScheduleReverseCacheRead(  # type: ignore # pylint: disable=no-member
+            self, block, read_buffer_index, storage_scope, dim_order
+        )
+
+    @type_checked
+    def reverse_cache_write(
+        self, block: BlockRV, write_buffer_index: int, storage_scope: str, dim_order: List[int] = []
+    ) -> BlockRV:
+        """Create a block that writes a buffer region into a write cache.
+        The index mapping is performed at the consumer rather than the producer.
+
+        Parameters
+        ----------
+        block : BlockRV
+            The producer block of the target buffer.
+        write_buffer_index: int
+            The index of the buffer in the block's write region.
+        storage_scope: str
+            The target storage scope.
+        dim_order: List[int]
+            The user-defined dimension order of the allocated buffer.
+            Numbers indicate the indices of block iter vars.
+
+        Returns
+        -------
+        cached_block : BlockRV
+            The block of the cache stage.
+        """
+        return _ffi_api.ScheduleReverseCacheWrite(  # type: ignore # pylint: disable=no-member
+            self, block, write_buffer_index, storage_scope, dim_order
+        )
+
     @type_checked
     def cache_inplace(
         self,
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 163c72eb0777..9f6868a85ee3 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -568,6 +568,30 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
   return CreateRV<BlockRV>(result);
 }
 
+BlockRV ConcreteScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                                               const String& storage_scope,
+                                               Array<Integer> dim_order) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::ReverseCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope,
+                                 dim_order);
+  TVM_TIR_SCHEDULE_END("reverse-cache-read", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<BlockRV>(result);
+}
+
+BlockRV ConcreteScheduleNode::ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                                                const String& storage_scope,
+                                                Array<Integer> dim_order) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::ReverseCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index,
+                                  storage_scope, dim_order);
+  TVM_TIR_SCHEDULE_END("reverse-cache-write", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<BlockRV>(result);
+}
+
 Array<StmtSRef> ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int write_buffer_index,
                                                    const String& storage_scope) {
   Array<StmtSRef> results;
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 899775f2a15d..20d7ce9cc7ec 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -116,6 +116,10 @@ class ConcreteScheduleNode : public ScheduleNode {
                     const Array<BlockRV> consumer_blocks = {}) override;
   BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope,
                      const Array<BlockRV> consumer_blocks = {}) override;
+  BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                           const String& storage_scope, Array<Integer> dim_order) override;
+  BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                            const String& storage_scope, Array<Integer> dim_order) override;
   Array<StmtSRef> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
                                const String& storage_scope) override;
   Array<BlockRV> CacheIndex(const BlockRV& block_rv,
                            int write_buffer_index) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 9e7f77f55ea5..4685a0a7066a 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -269,6 +269,37 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
 TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
                             const String& storage_scope,
                             const Array<StmtSRef> consumer_blocks = {});
+/*!
+ * \brief Create a block that reads a buffer region into a read cache. It requires:
+ *        1) There is at most one block that writes the buffer in the scope.
+ *        2) The scope block has the stage-pipeline property.
+ * Compared to cache_read, the index mapping is performed at the producer instead of the consumer.
+ * \param self The state of the schedule.
+ * \param block_sref The consumer block of the target buffer.
+ * \param read_buffer_index The index of the buffer in the block's read region.
+ * \param storage_scope The target storage scope.
+ * \param dim_order The user-defined dimension order of the allocated buffer.
+ * \return The cache stage block.
+ */
+TVM_DLL StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref,
+                                  int read_buffer_index, const String& storage_scope,
+                                  Array<Integer> dim_order);
+/*!
+ * \brief Create a block that writes a buffer region into a write cache. It requires:
+ *        1) There is only one block that writes the target buffer.
+ *        2) The scope block has the stage-pipeline property.
+ * Compared to cache_write, the index mapping is performed at the consumer instead of the producer.
+ * \param self The state of the schedule.
+ * \param block_sref The producer block of the buffer.
+ * \param write_buffer_index The index of the buffer in the block's write region.
+ * \param storage_scope The target storage scope.
+ * \param dim_order The user-defined dimension order of the allocated buffer.
+ * \return The cache stage block.
+ */
+TVM_DLL StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref,
+                                   int write_buffer_index, const String& storage_scope,
+                                   Array<Integer> dim_order);
+
 /*!
  * \brief Create 2 blocks that read&write a buffer region into a read/write cache.
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 4174a6699e06..25acc0218c4a 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -94,6 +94,126 @@ Optional<BufferRegion> GetBufferRegionFromBuffer(const Array<BufferRegion>& buff
   return res;
 }
 
+struct ReverseCacheTouchedInfo {
+  /*! \brief Whether the cached access is a read (true) or a write (false). */
+  bool read;
+  /*! \brief Touched loop variables and their ranges. */
+  Array<Var> loop_vars;
+  Array<Range> loop_ranges;
+  /*! \brief Touched block variables and their binding values. */
+  Array<IterVar> block_vars;
+  Array<PrimExpr> iter_values;
+};
+
+/*!
+ * \brief Create a loop nest that represents the reverse cache copy (reverse_read / reverse_write)
+ *        from the read buffer to the write buffer.
+ * \param cache_region The cached copy region.
+ * \param touched_info The touched loop/block variable information.
+ * \param info The cache stage information, which will be updated in the function.
+ * \param storage_scope The storage scope of the cached buffer (only used for naming here).
+ * \return A block indicating the body of the loop nest.
+ */
+Block MakeReverseCacheStage(const BufferRegion& cache_region, ReverseCacheTouchedInfo* touched_info,
+                            CacheStageInfo* info, const String& storage_scope) {
+  // Check whether the cache region is a single point.
+  bool single_point = true;
+  for (const Range& range : cache_region->region) {
+    const auto* ext_int = range->extent.as<IntImmNode>();
+    if (!ext_int || ext_int->value != 1) {
+      single_point = false;
+    }
+  }
+  CHECK(single_point) << "ReverseCacheStage is only valid when the cache region is a single point.";
+  // loop variables
+  std::vector<Var> loop_vars;
+  // block variables
+  Array<IterVar> block_vars;
+  // bindings in block realize
+  std::vector<PrimExpr> iter_values;
+  // Create loop vars and the block vars' binding values.
+  Map<Var, Var> var_map;
+  for (size_t i = 0; i < touched_info->loop_vars.size(); ++i) {
+    Var original_var = touched_info->loop_vars[i];
+    Var loop_var("ax" + std::to_string(i), original_var.dtype());
+    var_map.Set(original_var, loop_var);
+    loop_vars.push_back(loop_var);
+  }
+  for (size_t i = 0; i < touched_info->block_vars.size(); ++i) {
+    IterVar original_block_var = touched_info->block_vars[i];
+    PrimExpr original_iter_value = touched_info->iter_values[i];
+    IterVar block_var = IterVar(
+        /*dom=*/original_block_var->dom,
+        /*var=*/Var("v" + std::to_string(i), original_block_var->var.dtype()),
+        /*iter_type=*/kDataPar);
+    var_map.Set(original_block_var->var, block_var->var);
+    block_vars.push_back(block_var);
+    iter_values.push_back(Substitute(original_iter_value, var_map));
+  }
+
+  // Block access regions for the read/write buffers.
+  Region read_access_region, write_access_region;
+  Array<PrimExpr> read_access_indices, write_access_indices;
+  // Compute the read/write regions and read/write access indices.
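+  // For a reverse cache read, the original buffer keeps its single-point access
+  // (with loop/block vars substituted) on the read side, while the fresh block
+  // vars directly address the cache buffer on the write side; for a reverse
+  // cache write, the roles of the two buffers are swapped.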
+ for (const Range& range : cache_region->region) { + if (touched_info->read) { + read_access_indices.push_back(Substitute(range->min, var_map)); + read_access_region.push_back(Range::FromMinExtent(read_access_indices.back(), Integer(1))); + } else { + write_access_indices.push_back(Substitute(range->min, var_map)); + write_access_region.push_back(Range::FromMinExtent(write_access_indices.back(), Integer(1))); + } + } + for (const IterVar& block_var : block_vars) { + if (touched_info->read) { + write_access_indices.push_back(block_var->var); + write_access_region.push_back(Range::FromMinExtent(write_access_indices.back(), Integer(1))); + } else { + read_access_indices.push_back(block_var->var); + read_access_region.push_back(Range::FromMinExtent(read_access_indices.back(), Integer(1))); + } + } + + // Create New Block + Block block( + /*iter_vars*/ std::move(block_vars), + /*reads=*/{BufferRegion(info->read_buffer, read_access_region)}, + /*writes=*/{BufferRegion(info->write_buffer, write_access_region)}, + /*name_hint*/ cache_region->buffer->name + "_" + storage_scope, + /*body=*/ + BufferStore(info->write_buffer, BufferLoad(info->read_buffer, read_access_indices), + write_access_indices), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*buf_doms=*/{}); + // Create Block Realize node + Stmt body = BlockRealize(/*values=*/iter_values, + /*predicate=*/const_true(), + /*block=*/block); + // Create surrounding loops + for (size_t i = loop_vars.size(); i >= 1; --i) { + body = For(/*loop_var=*/loop_vars[i - 1], + /*min=*/touched_info->loop_ranges[i - 1]->min, + /*extent=*/touched_info->loop_ranges[i - 1]->extent, + /*kind=*/ForKind::kSerial, + /*body=*/body); + } + info->cache_stage = std::move(body); + return block; +} + /*! * \brief Create a loop nest that represents cache copy (cache_read / cache_write) from read buffer * to write buffer. @@ -600,6 +720,120 @@ class CacheInplaceLocDetector : public StmtVisitor { int loc_pos_{-1}; }; +/*! \brief Mutator for ReverseCacheRead. */ +class ReverseCacheReadRewriter : public StmtExprMutator { + public: + /*! + * \brief Rewrite the AST and add a cache_read stage with the information provided. + * \param scope_sref The parent scope of this mutation. + * \param info The cache stage information. + * \param touched_info The reverse cache touched information. + * \return The new AST rooting at the original parent scope. 
+ */ + static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info, + ReverseCacheTouchedInfo* touched_info) { + ReverseCacheReadRewriter rewriter(scope_sref, info, touched_info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit ReverseCacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info, + ReverseCacheTouchedInfo* touched_info) + : scope_sref_(scope_sref), info_(info) { + for (const IterVar& iter_var : touched_info->block_vars) { + new_indices_.push_back(iter_var->var); + } + } + + Stmt VisitStmt_(const ForNode* loop) final { + Stmt stmt = StmtMutator::VisitStmt_(loop); + // Check the insertion point + if (loop == info_->loc_sref->stmt) { + // Insert cache stage into the loop if it is the right place + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Stmt(n); + } + return stmt; + } + + Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { + Stmt stmt = StmtMutator::VisitStmt_(block_realize); + return stmt; + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef(block); + if (block != scope_sref_->stmt && + GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) { + return std::move(old_stmt); + } + // Mutate the body + Block stmt = Downcast(StmtMutator::VisitStmt_(block)); + // Check the insertion point + if (block == info_->loc_sref->stmt) { + // Insert cache stage into the block if it is the right place + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Block(n); + } + // Check if it is the block corresponding to the parent scope + if (block == scope_sref_->stmt) { + // If so, put buffer allocation on the parent scope + ObjectPtr n = make_object(*stmt.as()); + n->alloc_buffers.push_back(info_->alloc.value()); + stmt = Block(n); + } else { + // Otherwise, update read regions and match_buffers + Array reads; + for (const BufferRegion& buf_region : block->reads) { + if (buf_region->buffer.same_as(info_->read_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + reads.push_back(BufferRegion(info_->write_buffer, region)); + } else { + reads.push_back(buf_region); + } + } + + // NOTE(Zihao): do not process match buffers for now. + if (!reads.same_as(block->reads)) { + ObjectPtr n = make_object(*stmt.as()); + n->reads = std::move(reads); + stmt = Block(n); + } + } + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + if (op == info_->read_buffer->data.get()) { + return info_->write_buffer->data; + } + return GetRef(op); + } + + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.same_as(info_->read_buffer)) { + ObjectPtr n = make_object(*load); + n->buffer = info_->write_buffer; + n->indices = new_indices_; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + /*! \brief The parent scope of the insertion */ + const StmtSRef& scope_sref_; + /*! \brief The info for inserting cache stage */ + CacheStageInfo* info_; + /*! \brief The indices to use for new buffer. */ + Array new_indices_; +}; + /*! \brief Mutator for CacheRead. */ class CacheReadRewriter : public StmtExprMutator { public: @@ -717,6 +951,158 @@ class CacheReadRewriter : public StmtExprMutator { bool current_block_consumes; }; +/*! \brief Mutator for ReverseCacheWrite. 
*/ +class ReverseCacheWriteRewriter : public StmtExprMutator { + public: + /*! + * \brief Rewrite the AST and add a cache_write stage with the information provided. + * \param scope_sref The parent scope of this mutation. + * \param writer_block_sref The only writer block in the scope. + * \param info The cache stage information. + * \param touched_info The reverse cache touched information. + * \return The new AST rooting at the original parent scope. + */ + static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, + CacheStageInfo* info, ReverseCacheTouchedInfo* touched_info) { + ReverseCacheWriteRewriter rewriter(scope_sref, writer_block_sref, info, touched_info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit ReverseCacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, + CacheStageInfo* info, ReverseCacheTouchedInfo* touched_info) + : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) { + for (const IterVar& iter_var : touched_info->block_vars) { + new_indices_.push_back(iter_var->var); + } + } + + Stmt VisitStmt_(const ForNode* loop) final { + Stmt stmt = StmtMutator::VisitStmt_(loop); + // Check the insertion point + if (loop == info_->loc_sref->stmt) { + // Insert cache stage into the loop if it is the right place + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Stmt(n); + } + return stmt; + } + + Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { + Stmt stmt = StmtMutator::VisitStmt_(block_realize); + return stmt; + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef(block); + // We only mutate the block which generates info->write_buffer + if (block != writer_block_sref_->stmt && block != scope_sref_->stmt && !under_writer_block_) { + return std::move(old_stmt); + } + + // Mutate the body + bool under_scope = under_writer_block_ || block == writer_block_sref_->stmt; + std::swap(under_scope, under_writer_block_); + Block stmt = Downcast(StmtMutator::VisitStmt_(block)); + std::swap(under_scope, under_writer_block_); + + // Find the insertion point + if (block == info_->loc_sref->stmt) { + ObjectPtr n = make_object(*stmt.as()); + n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); + stmt = Block(n); + } + // Put buffer allocation on the parent scope + if (block == scope_sref_->stmt) { + ObjectPtr n = make_object(*stmt.as()); + n->alloc_buffers.push_back(info_->alloc.value()); + stmt = Block(n); + } else { + // Since cache_write changes the block, we need to update the buffer it writes + Array reads, writes; + // New reads info. + for (const BufferRegion& buf_region : block->reads) { + if (buf_region->buffer.same_as(info_->write_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + reads.push_back(BufferRegion(info_->read_buffer, region)); + } else { + reads.push_back(buf_region); + } + } + // New writes info, same as above. + for (const BufferRegion& buf_region : block->writes) { + if (buf_region->buffer.same_as(info_->write_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + writes.push_back(BufferRegion(info_->read_buffer, region)); + } else { + writes.push_back(buf_region); + } + } + + // NOTE(Zihao): do not process match buffers for now. 
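+      // Rebuild the block only if a read or write region actually changed, so
+      // untouched blocks keep reference equality and stay reusable downstream.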
+ if (!writes.same_as(block->writes) || !reads.same_as(block->reads)) { + ObjectPtr n = make_object(*stmt.as()); + n->writes = std::move(writes); + n->reads = std::move(reads); + stmt = Block(n); + } + } + // Remove atomic flag + ObjectPtr n = make_object(*stmt.as()); + n->annotations.erase("atomic"); + stmt = Block(n); + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + if (op == info_->write_buffer->data.get()) { + return info_->read_buffer->data; + } + return GetRef(op); + } + + Stmt VisitStmt_(const BufferStoreNode* store) final { + BufferStore stmt = Downcast(StmtMutator::VisitStmt_(store)); + if (stmt->buffer.same_as(info_->write_buffer)) { + auto n = CopyOnWrite(stmt.get()); + n->buffer = info_->read_buffer; + n->indices = new_indices_; + return Stmt(n); + } else { + return std::move(stmt); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.same_as(info_->write_buffer)) { + ObjectPtr n = make_object(*load); + n->buffer = info_->read_buffer; + n->indices = new_indices_; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + /*! \brief The parent scope of the insertion. */ + const StmtSRef& scope_sref_; + /*! \brief The parent scope of the insertion. */ + const StmtSRef& writer_block_sref_; + /*! \brief The info for inserting cache stage. */ + CacheStageInfo* info_; + /*! \brief The indices to use for new buffer. */ + Array new_indices_; + /*! \brief Whether the current node is under the given block. */ + bool under_writer_block_{false}; +}; + /*! \brief Mutator for CacheWrite */ class CacheWriteRewriter : public StmtExprMutator { public: @@ -1291,6 +1677,261 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu return result_block_sref; } +class VarCollector : public ExprVisitor { + public: + VarCollector() {} + std::unordered_set touched; + + private: + void VisitExpr_(const VarNode* op) final { touched.insert(op); } +}; + +Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& top_sref) { + std::vector result; + for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance(); + parent = parent->parent) { + if (parent == top_sref.get()) break; + result.push_back(GetRef(parent)); + } + return {result.rbegin(), result.rend()}; +} + +StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, + const String& storage_scope, Array dim_order) { + /*! + * Check: + * - The index is in the array of block reading region + * - There is at most one block who write the buffer in the scope + * + * Mutate: + * - Allocate new cache buffer under the current scope. + * - Find the lowest ancestor of the block and ANY ONE of the consumers blocks. + * - Copy the buffer with the consumed region. + */ + + // Step 0. Check the input storage scope. + CheckStorageScope(self, storage_scope); + + // Step 1. Check index, getting the target buffer and the parent scope + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + Buffer read_buffer = + GetNthAccessBuffer(self, GetRef(block), read_buffer_index, BufferIndexType::kRead); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); + + // Step 2. Create CacheStageInfo + CacheStageInfo info; + info.read_buffer = read_buffer; + + // Step 3. Update cache stage info. 
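+  // Locate the read region of the target buffer in the consumer block. Two cases
+  // follow: if some block in the scope writes the buffer, the cache stage is
+  // inserted right after that producer; otherwise the buffer is an input of the
+  // scope and the cache stage is placed at the beginning of the scope.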
+ Optional maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer); + ICHECK(maybe_region.defined()) << read_buffer + << " should appear in the block's read region: " << block->reads; + BufferRegion cache_region = maybe_region.value(); + if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { + // Case 1. The buffer is written inside the block. + StmtSRef write_block_sref = _write_block_sref.value(); + const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref); + // Find the producing region + StmtSRef parent_sref = GetRef(write_block_sref->parent); + // Detect insert position + CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info); + } else { + // Case 2. The buffer is the input block for the scope. + info.loc_sref = scope_sref; + info.loc_pos = 0; + } + + // Step 4. Create CacheTouchedInfo + ReverseCacheTouchedInfo touched_info; + + // Step 5. Update CacheTouchedInfo + touched_info.read = true; + VarCollector collector; + Array new_shape; + for (const Range& range : cache_region->region) { + collector(range->min); + } + BlockRealize realize = GetBlockRealize(self, block_sref); + std::unordered_set dim_order_set; + for (const Integer& idx : dim_order) { + dim_order_set.insert(idx->value); + } + for (size_t idx = 0; idx < block->iter_vars.size(); ++idx) { + const IterVar& block_var = block->iter_vars[idx]; + if (collector.touched.count(block_var->var.get())) { + if (dim_order_set.empty()) { + // no user defined dim order. + dim_order.push_back(idx); + } else { + // user provide dim order, check whether it's valid. + CHECK(dim_order_set.count(idx)) + << "Block iter_var " << block_var + << " used in the block, but doesn't appear in user-specified dim order array."; + } + } + } + + for (size_t i = 0; i < dim_order.size(); ++i) { + int idx = dim_order[i]->value; + const IterVar& block_var = block->iter_vars[idx]; + touched_info.block_vars.push_back(block_var); + touched_info.iter_values.push_back(realize->iter_values[idx]); + new_shape.push_back(block_var->dom->min + block_var->dom->extent); + collector(touched_info.iter_values.back()); + } + + for (const StmtSRef& loop_sref : GetLoopsUnderScope(block_sref, info.loc_sref)) { + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); + if (collector.touched.count(loop->loop_var.get())) { + touched_info.loop_vars.push_back(loop->loop_var); + touched_info.loop_ranges.push_back(Range::FromMinExtent(loop->min, loop->extent)); + } + } + + // Create write buffer. + ObjectPtr new_buffer = make_object(*read_buffer.get()); + ObjectPtr new_var = make_object(*read_buffer->data.get()); + const auto* ptr_type = TVM_TYPE_AS(read_buffer->data->type_annotation, PointerTypeNode); + new_var->type_annotation = PointerType(ptr_type->element_type, storage_scope); + new_buffer->data = Var(new_var->name_hint + "_" + storage_scope, new_var->type_annotation); + new_buffer->name = read_buffer->name + "_" + storage_scope; + new_buffer->shape = new_shape; + + info.write_buffer = Buffer(new_buffer); + info.alloc = info.write_buffer; + + // Step 6. Making new cache stage block and rewrite readers. + Block cache_read_stage = MakeReverseCacheStage(/*cache_region=*/cache_region, + /*touched_info=*/&touched_info, /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = ReverseCacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info, + /*touched_info=*/&touched_info); + + // Step 7. Replacing and updating flags. 
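+  // Commit the rewritten scope into the schedule state, then recompute the
+  // affine-binding flag for the new cache block; its region cover and the
+  // scope's stage pipeline still hold by construction.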
+ self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + +StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, + const String& storage_scope, Array dim_order) { + /*! + * Check: + * - The index is in the array of block reading region + * - There is only one block who write the buffer in the scope + * + * Mutate: + * - Allocate new cache buffer under the current scope. + * - Find the lowest ancestor of the block and ANY ONE of the producer blocks. + * - Copy the buffer with the consumed region. + */ + + // Step 0. Check the input storage scope. + CheckStorageScope(self, storage_scope); + + // Step 1. Checking index, getting the target buffer and the parent scope + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + Buffer write_buffer = + GetNthAccessBuffer(self, GetRef(block), write_buffer_index, BufferIndexType::kWrite); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + + // Step 2. Creating CacheStageInfo + CacheStageInfo info; + info.write_buffer = write_buffer; + + // Step 3. Check the only writer block. + ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); + + // Step 4. Find the producing region and insert position + Optional maybe_region = GetBufferRegionFromBuffer(block->writes, write_buffer); + ICHECK(maybe_region.defined()) << write_buffer << " should appear in the block's write region"; + StmtSRef parent_sref = GetRef(block_sref->parent); + // Detect insert position + CacheLocDetector::Detect(self, block_sref, scope_sref, &info); + BufferRegion cache_region = maybe_region.value(); + + // Step 5. Create CacheTouchedInfo + ReverseCacheTouchedInfo touched_info; + + // Step 6. Update CacheTouchedInfo + touched_info.read = false; + VarCollector collector; + Array new_shape; + for (const Range& range : cache_region->region) { + collector(range->min); + } + BlockRealize realize = GetBlockRealize(self, block_sref); + std::unordered_set dim_order_set; + for (const Integer& idx : dim_order) { + dim_order_set.insert(idx->value); + } + for (size_t idx = 0; idx < block->iter_vars.size(); ++idx) { + const IterVar& block_var = block->iter_vars[idx]; + if (collector.touched.count(block_var->var.get())) { + if (dim_order_set.empty()) { + // no user defined dim order. + dim_order.push_back(idx); + } else { + // user provide dim order, check whether it's valid. 
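+        // Every block iter var touched by the cached access must be listed in
+        // the user-specified dim_order.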
+ CHECK(dim_order_set.count(idx)) + << "Block iter_var " << block_var + << " used in the block, but doesn't appear in user-specified dim order array."; + } + } + } + + for (size_t i = 0; i < dim_order.size(); ++i) { + int idx = dim_order[i]->value; + const IterVar& block_var = block->iter_vars[idx]; + touched_info.block_vars.push_back(block_var); + touched_info.iter_values.push_back(realize->iter_values[idx]); + new_shape.push_back(block_var->dom->min + block_var->dom->extent); + collector(touched_info.iter_values.back()); + } + + for (const StmtSRef& loop_sref : GetLoopsUnderScope(block_sref, info.loc_sref)) { + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); + if (collector.touched.count(loop->loop_var.get())) { + touched_info.loop_vars.push_back(loop->loop_var); + touched_info.loop_ranges.push_back(Range::FromMinExtent(loop->min, loop->extent)); + } + } + + // Create write buffer. + ObjectPtr new_buffer = make_object(*write_buffer.get()); + ObjectPtr new_var = make_object(*write_buffer->data.get()); + const auto* ptr_type = TVM_TYPE_AS(write_buffer->data->type_annotation, PointerTypeNode); + new_var->type_annotation = PointerType(ptr_type->element_type, storage_scope); + new_buffer->data = Var(new_var->name_hint + "_" + storage_scope, new_var->type_annotation); + new_buffer->name = write_buffer->name + "_" + storage_scope; + new_buffer->shape = new_shape; + + info.read_buffer = Buffer(new_buffer); + info.alloc = info.read_buffer; + + // Step 7. Making new cache stage block and rewrite readers. + Block cache_write_stage = MakeReverseCacheStage(/*cache_region=*/cache_region, + /*touched_info=*/&touched_info, /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = ReverseCacheWriteRewriter::Rewrite( + /*scope_sref=*/scope_sref, + /*writer_block_sref=*/block_sref, /*info=*/&info, /*touched_info=*/&touched_info); + + // Step 8. Replacing and updating flags. + self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get()); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + /*! \brief The schedule error that the target block doesn't both read&write target buffer. 
*/ class NotReadWriteError : public ScheduleError { public: @@ -1600,9 +2241,70 @@ struct ReIndexTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct ReverseCacheReadTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReverseCacheRead"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 3; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer read_buffer_index, + String storage_scope, Array dim_order) { + return sch->ReverseCacheRead(block, read_buffer_index->value, storage_scope, dim_order); + } + + static String UnpackedAsPython(Array outputs, String block, Integer read_buffer_index, + String storage_scope, Array dim_order) { + PythonAPICall py("reverse_cache_read"); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.Input("dim_order", dim_order); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct ReverseCacheWriteTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReverseCacheWrite"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 3; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index, + String storage_scope, Array dim_order) { + return sch->ReverseCacheWrite(block, write_buffer_index->value, storage_scope, dim_order); + } + + static String UnpackedAsPython(Array outputs, String block, Integer write_buffer_index, + String storage_scope, Array dim_order) { + PythonAPICall py("reverse_cache_write"); + py.Input("block", block); + py.Input("write_buffer_index", write_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.Input("dim_order", dim_order); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheInplaceTraits); TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReverseCacheReadTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReverseCacheWriteTraits); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index d008f3639c78..6b47397c7fc3 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -179,6 +179,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseCacheRead") + .set_body_method(&ScheduleNode::ReverseCacheRead); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseCacheWrite") + .set_body_method(&ScheduleNode::ReverseCacheWrite); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace") .set_body_method(&ScheduleNode::CacheInplace); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex") diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 70559608e789..079e5057f7b1 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -309,6 
+309,36 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } +BlockRV TracedScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, + Array dim_order) { + BlockRV result = + ConcreteScheduleNode::ReverseCacheRead(block_rv, read_buffer_index, storage_scope, dim_order); + + static const InstructionKind& kind = InstructionKind::Get("ReverseCacheRead"); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(read_buffer_index), storage_scope, dim_order}, + /*outputs=*/{result})); + return result; +} + +BlockRV TracedScheduleNode::ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, + Array dim_order) { + BlockRV result = ConcreteScheduleNode::ReverseCacheWrite(block_rv, write_buffer_index, + storage_scope, dim_order); + + static const InstructionKind& kind = InstructionKind::Get("ReverseCacheWrite"); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(write_buffer_index), storage_scope, dim_order}, + /*outputs=*/{result})); + return result; +} + Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) { Array result = diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index c54574e9c9ff..2c6a6f88827a 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -76,6 +76,10 @@ class TracedScheduleNode : public ConcreteScheduleNode { const Array consumer_blocks = {}) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) final; + BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, Array dim_order) final; + BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, Array dim_order) final; Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) final; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, From 8aecf7a07e13cc182d0a6d225c1c908e5ecc8d0b Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 30 Dec 2022 08:27:53 -0800 Subject: [PATCH 02/13] upd --- include/tvm/tir/schedule/schedule.h | 8 +- python/tvm/tir/schedule/schedule.py | 22 ++- src/tir/schedule/concrete_schedule.cc | 20 ++- src/tir/schedule/concrete_schedule.h | 6 +- src/tir/schedule/primitive.h | 8 +- .../schedule/primitive/cache_read_write.cc | 132 ++++++++++++++---- src/tir/schedule/traced_schedule.cc | 34 ++--- src/tir/schedule/traced_schedule.h | 6 +- 8 files changed, 176 insertions(+), 60 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index f322188f52c5..51de89fff0a6 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -414,10 +414,12 @@ class ScheduleNode : public runtime::Object { * \param read_buffer_index The index of the buffer in block's read region. * \param storage_scope The target storage scope. * \param dim_order The user-defined dimension order of allocated buffer. + * \param consumer_blocks An optional list of consumers to read from cache directly. * \return The cache stage block. 
*/ virtual BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, Array dim_order) = 0; + const String& storage_scope, Array dim_order, + Array consumer_blocks) = 0; /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block who writes the target buffer. @@ -427,10 +429,12 @@ class ScheduleNode : public runtime::Object { * \param write_buffer_index The index of the buffer in block's write region * \param storage_scope The target storage scope * \param dim_order The user-defined dimension order of allocated buffer. + * \param consumer_blocks An optional list of consumers to read from cache directly. * \return The cache stage block. */ virtual BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, Array dim_order) = 0; + const String& storage_scope, Array dim_order, + Array consumer_blocks) = 0; /*! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. * It requires the the target block both read & write the target buffer. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index cec6b5f70508..69dc534107ad 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1200,7 +1200,8 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: @type_checked def reverse_cache_read( - self, block: BlockRV, read_buffer_index: int, storage_scope: str, dim_order: List[int] = [] + self, block: BlockRV, read_buffer_index: int, storage_scope: str, dim_order: List[int] = [], + consumer_blocks=None, ) -> BlockRV: """Create a block that reads a buffer region into a read cache. The index mapping was performed at producer rather than consumer. @@ -1215,18 +1216,25 @@ def reverse_cache_read( dim_order: List[int] The user-defined dimension order of allocated buffer. Numbers indicate the index of block iter vars. + consumer_blocks: Optional[List[Union[BlockRV, str]]] + An optional list of consumers that should read directly from the cache. + If not specified, all consumers will read from the original buffer. + Returns ------- cached_block : BlockRV The block of the cache stage """ + if consumer_blocks is None: + consumer_blocks = [] return _ffi_api.ScheduleReverseCacheRead( # type: ignore # pylint: disable=no-member - self, block, read_buffer_index, storage_scope, dim_order + self, block, read_buffer_index, storage_scope, dim_order, consumer_blocks ) @type_checked def reverse_cache_write( - self, block: BlockRV, write_buffer_index: int, storage_scope: str, dim_order: List[int] = [] + self, block: BlockRV, write_buffer_index: int, storage_scope: str, dim_order: List[int] = [], + consumer_blocks=None, ) -> BlockRV: """Create a block that reads a buffer region into a write cache. The index mapping was performed at consumer rather than producer. @@ -1241,13 +1249,19 @@ def reverse_cache_write( dim_order: List[int] The user-defined dimension order of allocated buffer. Numbers indicate the index of block iter vars. + consumer_blocks: Optional[List[Union[BlockRV, str]]] + An optional list of consumers that should read directly from the cache. + If not specified, all consumers will read from the original buffer. 
+ Returns ------- cached_block : BlockRV The block of the cache stage """ + if consumer_blocks is None: + consumer_blocks = [] return _ffi_api.ScheduleReverseCacheWrite( # type: ignore # pylint: disable=no-member - self, block, write_buffer_index, storage_scope, dim_order + self, block, write_buffer_index, storage_scope, dim_order, consumer_blocks ) @type_checked diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 9f6868a85ee3..5ebb503ed16f 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -570,11 +570,17 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff BlockRV ConcreteScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope, - Array dim_order) { + Array dim_order, + const Array consumer_blocks) { StmtSRef result{nullptr}; + // Create a new array of SRefs from the consumer block list. + Array consumer_block_refs = {}; + for (BlockRV block : consumer_blocks) { + consumer_block_refs.push_back(this->GetSRef(block)); + } TVM_TIR_SCHEDULE_BEGIN(); result = tir::ReverseCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope, - dim_order); + dim_order, consumer_block_refs); TVM_TIR_SCHEDULE_END("reverse-cache-read", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -582,11 +588,17 @@ BlockRV ConcreteScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read BlockRV ConcreteScheduleNode::ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, - Array dim_order) { + Array dim_order, + const Array consumer_blocks) { StmtSRef result{nullptr}; + // Create a new array of SRefs from the consumer block list. 
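+  // Each consumer BlockRV handle is resolved to its sref, since the TIR-level
+  // primitive below operates on srefs rather than schedule random variables.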
+ Array consumer_block_refs = {}; + for (BlockRV block : consumer_blocks) { + consumer_block_refs.push_back(this->GetSRef(block)); + } TVM_TIR_SCHEDULE_BEGIN(); result = tir::ReverseCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, - storage_scope, dim_order); + storage_scope, dim_order, consumer_block_refs); TVM_TIR_SCHEDULE_END("reverse-cache-write", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 20d7ce9cc7ec..c5cf138d0eed 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -117,9 +117,11 @@ class ConcreteScheduleNode : public ScheduleNode { BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) override; BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, Array dim_order) override; + const String& storage_scope, Array dim_order = {}, + Array consumer_blocks = {}) override; BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, Array dim_order) override; + const String& storage_scope, Array dim_order = {}, + Array consumer_blocks = {}) override; Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) override; Array CacheIndex(const BlockRV& block_rv, int write_buffer_index) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 4685a0a7066a..8fd322a0be92 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -279,11 +279,13 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int * \param read_buffer_index The index of the buffer in block's read region. * \param storage_scope The target storage scope. * \param dim_order The user-defined dimension order of allocated buffer. + * \param consumer_blocks Array of blocks that consume the cache. * \return The cache stage block. */ TVM_DLL StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, const String& storage_scope, - Array dim_order); + Array dim_order = {}, + Array consumer_blocks = {}); /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block that writes the target buffer. @@ -294,11 +296,13 @@ TVM_DLL StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref * \param write_buffer_index The index of the buffer in block's write region * \param storage_scope The target storage scope * \param dim_order The user-defined dimension order of allocated buffer. + * \param consumer_blocks Array of blocks that consume the cache. * \return The cache stage block. */ TVM_DLL StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope, - Array dim_order); + Array dim_order = {}, + Array consumer_blocks = {}); /*! *! diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 25acc0218c4a..46a9b409a9b9 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -764,6 +764,19 @@ class ReverseCacheReadRewriter : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* block) final { Block old_stmt = GetRef(block); + // Check if this block is one of the specified consumers. 
+ // If no consumer blocks are specified, all blocks should be considered consumers. + bool is_consumer = info_->consumer_blocks.empty(); + // Otherwise check if this is one of the specified blocks. + for (StmtSRef consumer_sref : info_->consumer_blocks) { + const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); + Block consumer_block = GetRef(consumer_node); + if (old_stmt.same_as(consumer_block)) { + is_consumer = true; + } + } + // Keep track of this blocks status. We'll use this when rewriting loads. + current_block_consumes = is_consumer; if (block != scope_sref_->stmt && GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) { return std::move(old_stmt); @@ -785,24 +798,27 @@ class ReverseCacheReadRewriter : public StmtExprMutator { stmt = Block(n); } else { // Otherwise, update read regions and match_buffers - Array reads; - for (const BufferRegion& buf_region : block->reads) { - if (buf_region->buffer.same_as(info_->read_buffer)) { - Array region; - for (const PrimExpr index : new_indices_) { - region.push_back(Range::FromMinExtent(index, Integer(1))); + // Only make this change if the block is one of the specified consumers. + if (is_consumer) { + Array reads; + for (const BufferRegion& buf_region : block->reads) { + if (buf_region->buffer.same_as(info_->read_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + reads.push_back(BufferRegion(info_->write_buffer, region)); + } else { + reads.push_back(buf_region); } - reads.push_back(BufferRegion(info_->write_buffer, region)); - } else { - reads.push_back(buf_region); } - } - // NOTE(Zihao): do not process match buffers for now. - if (!reads.same_as(block->reads)) { - ObjectPtr n = make_object(*stmt.as()); - n->reads = std::move(reads); - stmt = Block(n); + // NOTE(Zihao): do not process match buffers for now. + if (!reads.same_as(block->reads)) { + ObjectPtr n = make_object(*stmt.as()); + n->reads = std::move(reads); + stmt = Block(n); + } } } info_->block_reuse.Set(old_stmt, stmt); @@ -817,7 +833,7 @@ class ReverseCacheReadRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const BufferLoadNode* load) final { - if (load->buffer.same_as(info_->read_buffer)) { + if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { ObjectPtr n = make_object(*load); n->buffer = info_->write_buffer; n->indices = new_indices_; @@ -832,6 +848,8 @@ class ReverseCacheReadRewriter : public StmtExprMutator { CacheStageInfo* info_; /*! \brief The indices to use for new buffer. */ Array new_indices_; + /*! \brief Whether the most recently visited block is a specified consumer. */ + bool current_block_consumes; }; /*! \brief Mutator for CacheRead. */ @@ -996,6 +1014,30 @@ class ReverseCacheWriteRewriter : public StmtExprMutator { Stmt VisitStmt_(const BlockNode* block) final { Block old_stmt = GetRef(block); + + // Check if this block is one of the specified cache consumers. + // update the read buffer to the cache. 
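+    // A matching consumer has its read regions and match_buffers redirected
+    // from the original buffer to the cache buffer, and is then returned
+    // without applying the writer-side rewriting below.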
+ for (StmtSRef consumer_sref : info_->consumer_blocks) { + const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); + Block consumer_block = GetRef(consumer_node); + if (old_stmt.same_as(consumer_block)) { + Array reads = + ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer); + Array match_buffers = + ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer); + if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { + auto n = CopyOnWrite(block); + n->reads = std::move(reads); + n->match_buffers = std::move(match_buffers); + n->body = VisitStmt(block->body); + Block new_consumer = Block(n); + info_->block_reuse.Set(old_stmt, new_consumer); + return std::move(new_consumer); + } + return std::move(old_stmt); + } + } + // We only mutate the block which generates info->write_buffer if (block != writer_block_sref_->stmt && block != scope_sref_->stmt && !under_writer_block_) { return std::move(old_stmt); @@ -1697,7 +1739,8 @@ Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& t } StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, - const String& storage_scope, Array dim_order) { + const String& storage_scope, Array dim_order, + const Array consumer_blocks) { /*! * Check: * - The index is in the array of block reading region @@ -1723,6 +1766,14 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re CacheStageInfo info; info.read_buffer = read_buffer; + // info.consumer_blocks indicates which buffers should consume the cache. + for (auto consumer : consumer_blocks) { + info.consumer_blocks.insert(consumer); + for (auto child : tir::GetChildBlocks(self, consumer)) { + info.consumer_blocks.insert(child); + } + } + // Step 3. Update cache stage info. Optional maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer); ICHECK(maybe_region.defined()) << read_buffer @@ -1819,7 +1870,8 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re } StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, - const String& storage_scope, Array dim_order) { + const String& storage_scope, Array dim_order, + const Array consumer_blocks) { /*! * Check: * - The index is in the array of block reading region @@ -1844,6 +1896,14 @@ StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w CacheStageInfo info; info.write_buffer = write_buffer; + // info.consumer_blocks indicates which buffers should consume the cache. + for (auto consumer : consumer_blocks) { + info.consumer_blocks.insert(consumer); + for (auto child : tir::GetChildBlocks(self, consumer)) { + info.consumer_blocks.insert(child); + } + } + // Step 3. Check the only writer block. 
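+  // The given block must be the unique writer of the target buffer in this
+  // scope; otherwise redirecting its store into the cache would be unsound.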
ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); @@ -2247,21 +2307,29 @@ struct ReverseCacheReadTraits : public UnpackedInstTraits dim_order) { - return sch->ReverseCacheRead(block, read_buffer_index->value, storage_scope, dim_order); + String storage_scope, Array dim_order, + Array consumer_blocks) { + return sch->ReverseCacheRead(block, read_buffer_index->value, storage_scope, dim_order, + consumer_blocks); } static String UnpackedAsPython(Array outputs, String block, Integer read_buffer_index, - String storage_scope, Array dim_order) { + String storage_scope, Array dim_order, + Array consumer_blocks) { PythonAPICall py("reverse_cache_read"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); py.Input("storage_scope", storage_scope); - py.Input("dim_order", dim_order); + if (!dim_order.empty()) { + py.Input("dim_order", dim_order); + } + if (!consumer_blocks.empty()) { + py.Input("consumer_blocks", consumer_blocks); + } py.SingleOutput(outputs); return py.Str(); } @@ -2276,21 +2344,29 @@ struct ReverseCacheWriteTraits : public UnpackedInstTraits dim_order) { - return sch->ReverseCacheWrite(block, write_buffer_index->value, storage_scope, dim_order); + String storage_scope, Array dim_order, + Array consumer_blocks) { + return sch->ReverseCacheWrite(block, write_buffer_index->value, storage_scope, dim_order, + consumer_blocks); } static String UnpackedAsPython(Array outputs, String block, Integer write_buffer_index, - String storage_scope, Array dim_order) { + String storage_scope, Array dim_order, + Array consumer_blocks) { PythonAPICall py("reverse_cache_write"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index->value); py.Input("storage_scope", storage_scope); - py.Input("dim_order", dim_order); + if (!dim_order.empty()) { + py.Input("dim_order", dim_order); + } + if (!dim_order.empty()) { + py.Input("consumer_blocks", consumer_blocks); + } py.SingleOutput(outputs); return py.Str(); } diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 079e5057f7b1..879e6def5815 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -310,32 +310,34 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer } BlockRV TracedScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, - Array dim_order) { - BlockRV result = - ConcreteScheduleNode::ReverseCacheRead(block_rv, read_buffer_index, storage_scope, dim_order); + const String& storage_scope, Array dim_order, + Array consumer_blocks) { + BlockRV result = ConcreteScheduleNode::ReverseCacheRead( + block_rv, read_buffer_index, storage_scope, dim_order, consumer_blocks); static const InstructionKind& kind = InstructionKind::Get("ReverseCacheRead"); trace_->Append( - /*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv}, - /*attrs=*/{Integer(read_buffer_index), storage_scope, dim_order}, - /*outputs=*/{result})); + /*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(read_buffer_index), storage_scope, dim_order, consumer_blocks}, + /*outputs=*/{result})); return result; } BlockRV TracedScheduleNode::ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, - Array dim_order) { - BlockRV result = ConcreteScheduleNode::ReverseCacheWrite(block_rv, write_buffer_index, - storage_scope, dim_order); + const String& storage_scope, Array 
+                                              Array consumer_blocks) {
+  BlockRV result = ConcreteScheduleNode::ReverseCacheWrite(
+      block_rv, write_buffer_index, storage_scope, dim_order, consumer_blocks);
   static const InstructionKind& kind = InstructionKind::Get("ReverseCacheWrite");
   trace_->Append(
-      /*inst=*/Instruction(/*kind=*/kind,
-                           /*inputs=*/{block_rv},
-                           /*attrs=*/{Integer(write_buffer_index), storage_scope, dim_order},
-                           /*outputs=*/{result}));
+      /*inst=*/Instruction(
+          /*kind=*/kind,
+          /*inputs=*/{block_rv},
+          /*attrs=*/{Integer(write_buffer_index), storage_scope, dim_order, consumer_blocks},
+          /*outputs=*/{result}));
   return result;
 }
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 2c6a6f88827a..efd6311e34c2 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -77,9 +77,11 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope,
                      const Array consumer_blocks = {}) final;
   BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
-                           const String& storage_scope, Array dim_order) final;
+                           const String& storage_scope, Array dim_order = {},
+                           Array consumer_blocks = {}) final;
   BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
-                            const String& storage_scope, Array dim_order) final;
+                            const String& storage_scope, Array dim_order = {},
+                            Array consumer_blocks = {}) final;
   Array CacheInplace(const BlockRV& block_rv, int read_buffer_index,
                      const String& storage_scope) final;
   BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,

From 88352259ed6425d9e5d6172d2771cd6eaae45415 Mon Sep 17 00:00:00 2001
From: Zihao 
Date: Tue, 28 Feb 2023 12:06:25 -0800
Subject: [PATCH 03/13] update

---
 include/tvm/tir/schedule/schedule.h           |  16 +-
 python/tvm/tir/schedule/schedule.py           | 134 ++-
 src/tir/schedule/concrete_schedule.cc         |  16 +-
 src/tir/schedule/concrete_schedule.h          |   8 +-
 src/tir/schedule/primitive.h                  |  16 +-
 .../schedule/primitive/cache_read_write.cc    | 925 ++++++++----------
 src/tir/schedule/schedule.cc                  |   8 +-
 src/tir/schedule/traced_schedule.cc           |  24 +-
 src/tir/schedule/traced_schedule.h            |   8 +-
 9 files changed, 580 insertions(+), 575 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index add4c7e42eca..d3329a6a0339 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -409,31 +409,31 @@ class ScheduleNode : public runtime::Object {
    * \brief Create a block that reads a buffer region into a read cache. It requires:
    * 1) There is at most one block who writes the buffer in the scope.
    * 2) The scope block have stage-pipeline property.
-   * Compared to cache read, the index mapping was performed at producer rather than consumer.
+   * Compared to cache read, the indices used to access the allocated cache buffer are
+   * customized by the user.
    * \param block_rv The consumer block of the target buffer.
    * \param read_buffer_index The index of the buffer in block's read region.
   * \param storage_scope The target storage scope.
-   * \param dim_order The user-defined dimension order of allocated buffer.
+   * \param index_map User-defined indices to access the allocated cache buffer, mapped from
+   * block iter vars.
   * \param consumer_blocks An optional list of consumers to read from cache directly.
   * \return The cache stage block.
   */
-  virtual BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
-                                   const String& storage_scope, Array dim_order,
+  virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                                   const String& storage_scope, const IndexMap& index_map,
                                    Array consumer_blocks) = 0;
   /*!
    * \brief Create a block that writes a buffer region into a write cache. It requires:
    * 1) There is only one block who writes the target buffer.
    * 2) The scope block have stage-pipeline property.
-   * Compared to cache write, the index mapping was performed at consumer rather than producer.
+   * Compared to cache write, the indices used to access the allocated cache buffer are
+   * customized by the user.
    * \param block_rv The producer of the buffer
    * \param write_buffer_index The index of the buffer in block's write region
    * \param storage_scope The target storage scope
-   * \param dim_order The user-defined dimension order of allocated buffer.
+   * \param index_map User-defined indices to access the allocated cache buffer, mapped from
+   * block iter vars.
    * \param consumer_blocks An optional list of consumers to read from cache directly.
    * \return The cache stage block.
    */
-  virtual BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
-                                    const String& storage_scope, Array dim_order,
+  virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                                    const String& storage_scope, const IndexMap& index_map,
                                     Array consumer_blocks) = 0;
   /*!
    * \brief Create 2 blocks that read&write a buffer region into a read/write cache.
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index ab8555ee0f64..44c5ca4fb4ca 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1204,12 +1204,15 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
         )
 
     @type_checked
-    def reverse_cache_read(
-        self, block: BlockRV, read_buffer_index: int, storage_scope: str, dim_order: List[int] = [],
+    def reindex_cache_read(
+        self, block: BlockRV, read_buffer_index: int, storage_scope: str,
+        index_map: Union[IndexMap, Callable],
         consumer_blocks=None,
     ) -> BlockRV:
-        """Create a block that reads a buffer region into a read cache.
-        The index mapping was performed at producer rather than consumer.
+        """Create a block that reads a buffer region into a read cache, with user-customized
+        indices specified by an index map.
+        The block's read region of the buffer must be a single point.
+
         Parameters
         ----------
         block : BlockRV
@@ -1218,9 +1221,8 @@ def reverse_cache_read(
             The index of the buffer in block's read region.
         storage_scope: str
             The target storage scope.
-        dim_order: List[int]
-            The user-defined dimension order of allocated buffer.
-            Numbers indicate the index of block iter vars.
+        index_map: Union[IndexMap, Callable]
+            User-defined indices to access the allocated cache buffer, mapped from block iter vars.
         consumer_blocks: Optional[List[Union[BlockRV, str]]]
             An optional list of consumers that should read directly from the cache.
             If not specified, all consumers will read from the original buffer.
@@ -1229,20 +1231,68 @@ def reverse_cache_read(
         -------
         cached_block : BlockRV
             The block of the cache stage
+
+        Examples
+        --------
+        Before reindex_cache_read, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_reindex_cache_read(a: T.handle, b: T.handle) -> None:
+                A = T.match_buffer(a, (128, 128))
+                B = T.match_buffer(b, (128, 128))
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and reindex_cache_read:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_reindex_cache_read)
+            block_b = sch.get_block("B")
+            sch.reindex_cache_read(block_b, 0, "local", lambda vi, vj: (vj, vi))
+            print(sch.mod["main"].script())
+
+        After applying reindex_cache_read, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_reindex_cache_read(a: T.handle, b: T.handle) -> None:
+                A = T.match_buffer(a, (128, 128))
+                B = T.match_buffer(b, (128, 128))
+                A_local = T.alloc_buffer((128, 128), scope="local")
+                for i, j in T.grid(128, 128):
+                    with T.block("A_local"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        A_local[vj, vi] = A[vi, vj]
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A_local[vj, vi] * 2.0
+        """
         if consumer_blocks is None:
             consumer_blocks = []
-        return _ffi_api.ScheduleReverseCacheRead(  # type: ignore # pylint: disable=no-member
-            self, block, read_buffer_index, storage_scope, dim_order, consumer_blocks
+        if callable(index_map):
+            index_map = IndexMap.from_func(index_map)
+        return _ffi_api.ScheduleReindexCacheRead(  # type: ignore # pylint: disable=no-member
+            self, block, read_buffer_index, storage_scope, index_map, consumer_blocks
         )
 
     @type_checked
-    def reverse_cache_write(
-        self, block: BlockRV, write_buffer_index: int, storage_scope: str, dim_order: List[int] = [],
+    def reindex_cache_write(
+        self, block: BlockRV, write_buffer_index: int, storage_scope: str,
+        index_map: Union[Callable, IndexMap],
         consumer_blocks=None,
     ) -> BlockRV:
-        """Create a block that reads a buffer region into a write cache.
-        The index mapping was performed at consumer rather than producer.
+        """Create a block that writes a buffer region into a write cache, with user-customized
+        indices specified by an index map.
+        The block's write region of the buffer must be a single point.
+
         Parameters
         ----------
         block : BlockRV
@@ -1251,9 +1301,8 @@ def reverse_cache_write(
             The index of the buffer in block's write region.
         storage_scope: str
             The target storage scope.
-        dim_order: List[int]
-            The user-defined dimension order of allocated buffer.
-            Numbers indicate the index of block iter vars.
+        index_map: Union[Callable, IndexMap]
+            User-defined indices to access the allocated cache buffer, mapped from block iter vars.
         consumer_blocks: Optional[List[Union[BlockRV, str]]]
             An optional list of consumers that should read directly from the cache.
             If not specified, all consumers will read from the original buffer.
@@ -1262,11 +1311,56 @@ def reverse_cache_write(
         -------
         cached_block : BlockRV
             The block of the cache stage
+
+        Examples
+        --------
+        Before reindex_cache_write, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_reindex_cache_write(a: T.handle, b: T.handle) -> None:
+                A = T.match_buffer(a, (128, 128))
+                B = T.match_buffer(b, (128, 128))
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and reindex_cache_write:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_reindex_cache_write)
+            block_b = sch.get_block("B")
+            sch.reindex_cache_write(block_b, 0, "local", lambda vi, vj: (vi // 2, vi % 2, vj))
+            print(sch.mod["main"].script())
+
+        After applying reindex_cache_write, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_reindex_cache_write(a: T.handle, b: T.handle) -> None:
+                A = T.match_buffer(a, (128, 128))
+                B = T.match_buffer(b, (128, 128))
+                B_local = T.alloc_buffer((64, 2, 128), scope="local")
+                for i, j in T.grid(128, 128):
+                    with T.block("B_local"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B_local[vi // 2, vi % 2, vj] = A[vi, vj] * 2.0
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = B_local[vi // 2, vi % 2, vj]
+        """
         if consumer_blocks is None:
             consumer_blocks = []
-        return _ffi_api.ScheduleReverseCacheWrite(  # type: ignore # pylint: disable=no-member
-            self, block, write_buffer_index, storage_scope, dim_order, consumer_blocks
+        if callable(index_map):
+            index_map = IndexMap.from_func(index_map)
+        return _ffi_api.ScheduleReindexCacheWrite(  # type: ignore # pylint: disable=no-member
+            self, block, write_buffer_index, storage_scope, index_map, consumer_blocks
         )
 
     @type_checked
@@ -1491,7 +1585,7 @@ def reindex(
         Examples
         --------
-        Before transform_layout, in TensorIR, the IR is:
+        Before reindex, in TensorIR, the IR is:
 
         .. code-block:: python
 
@@ -1505,7 +1599,7 @@ def before_reindex(
                 vi, vj = T.axis.remap("SS", [i, j])
                 B[vi, vj] = A[vj, vi] * 2.0
 
-        Create the schedule and do transform_layout:
+        Create the schedule and do reindex:
 
         .. code-block:: python
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 3b4dee75ca1d..89c454178f6d 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -568,9 +568,9 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
   return CreateRV(result);
 }
 
-BlockRV ConcreteScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
+BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
                                                const String& storage_scope,
-                                               Array dim_order,
+                                               const IndexMap& index_map,
                                                const Array consumer_blocks) {
   StmtSRef result{nullptr};
   // Create a new array of SRefs from the consumer block list.
@@ -579,16 +579,16 @@ BlockRV ConcreteScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read
     consumer_block_refs.push_back(this->GetSRef(block));
   }
   TVM_TIR_SCHEDULE_BEGIN();
-  result = tir::ReverseCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope,
-                                 dim_order, consumer_block_refs);
+  result = tir::ReindexCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope,
+                                 index_map, consumer_block_refs);
-  TVM_TIR_SCHEDULE_END("reverse-cache-read", this->error_render_level_);
+  TVM_TIR_SCHEDULE_END("reindex-cache-read", this->error_render_level_);
   this->state_->DebugVerify();
   return CreateRV(result);
 }
 
-BlockRV ConcreteScheduleNode::ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+BlockRV ConcreteScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
                                                 const String& storage_scope,
-                                                Array dim_order,
+                                                const IndexMap& index_map,
                                                 const Array consumer_blocks) {
   StmtSRef result{nullptr};
   // Create a new array of SRefs from the consumer block list.
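   // (Consumer blocks are given as BlockRVs; resolving them to StmtSRefs here
   // lets the TIR-level primitive identify consumers by reference.)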
@@ -597,8 +597,8 @@ BlockRV ConcreteScheduleNode::ReverseCacheWrite(const BlockRV& block_rv, int wri
     consumer_block_refs.push_back(this->GetSRef(block));
   }
   TVM_TIR_SCHEDULE_BEGIN();
-  result = tir::ReverseCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index,
-                                  storage_scope, dim_order, consumer_block_refs);
+  result = tir::ReindexCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index,
+                                  storage_scope, index_map, consumer_block_refs);
-  TVM_TIR_SCHEDULE_END("reverse-cache-write", this->error_render_level_);
+  TVM_TIR_SCHEDULE_END("reindex-cache-write", this->error_render_level_);
   this->state_->DebugVerify();
   return CreateRV(result);
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 6e34d1e14a29..bb4c712f1e28 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -116,11 +116,11 @@ class ConcreteScheduleNode : public ScheduleNode {
                     const Array consumer_blocks = {}) override;
   BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope,
                      const Array consumer_blocks = {}) override;
-  BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
-                           const String& storage_scope, Array dim_order = {},
+  BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                           const String& storage_scope, const IndexMap& index_map,
                            Array consumer_blocks = {}) override;
-  BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
-                            const String& storage_scope, Array dim_order = {},
+  BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                            const String& storage_scope, const IndexMap& index_map,
                             Array consumer_blocks = {}) override;
   Array CacheInplace(const BlockRV& block_rv, int read_buffer_index,
                      const String& storage_scope) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index b535ca92dcbb..1ad75c313ce0 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -273,35 +273,35 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int
 * \brief Create a block that reads a buffer region into a read cache. It requires:
 * 1) There is at most one block who writes the buffer in the scope.
 * 2) The scope block have stage-pipeline property.
- * Compared to cache read, the index mapping was performed at producer instead of consumer.
+ * Compared to cache read, the indices used to access the allocated cache buffer are
+ * customized by the user.
 * \param self The state of the schedule
 * \param block_sref The consumer block of the target buffer.
 * \param read_buffer_index The index of the buffer in block's read region.
 * \param storage_scope The target storage scope.
- * \param dim_order The user-defined dimension order of allocated buffer.
+ * \param index_map User-defined indices to access the allocated cache buffer, mapped from
+ * block iter vars.
 * \param consumer_blocks Array of blocks that consume the cache.
 * \return The cache stage block.
 */
-TVM_DLL StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref,
+TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref,
                                   int read_buffer_index, const String& storage_scope,
-                                  Array dim_order = {},
+                                  const IndexMap& index_map,
                                   Array consumer_blocks = {});
 /*!
 * \brief Create a block that writes a buffer region into a write cache. It requires:
 * 1) There is only one block that writes the target buffer.
 * 2) The scope block have stage-pipeline property.
- * Compared to cache write, the index mapping was performed at producer instead of consumer.
+ * Compared to cache write, the indices used to access the allocated cache buffer are
+ * customized by the user.
 * \param self The state of the schedule
 * \param block_sref The producer of the buffer
 * \param write_buffer_index The index of the buffer in block's write region
 * \param storage_scope The target storage scope
- * \param dim_order The user-defined dimension order of allocated buffer.
+ * \param index_map User-defined indices to access the allocated cache buffer, mapped from
+ * block iter vars.
 * \param consumer_blocks Array of blocks that consume the cache.
 * \return The cache stage block.
 */
-TVM_DLL StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref,
+TVM_DLL StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref,
                                    int write_buffer_index, const String& storage_scope,
-                                   Array dim_order = {},
+                                   const IndexMap& index_map,
                                    Array consumer_blocks = {});
 
 /*!
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 6d6cb6948468..0c42b7f09844 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -94,48 +94,60 @@ Optional GetBufferRegionFromBuffer(const Array& buff
   return res;
 }
 
-struct ReverseCacheTouchedInfo {
-  /* Whether read or write. */
-  bool read;
+struct ReindexCacheStageInfo : CacheStageInfo {
+  /* Indices used to access the allocated cache buffer. */
+  Array indices;
   /* Touched loop variable related information. */
   Array loop_vars;
   Array loop_ranges;
   /* Touched block variable related information. */
-  Array block_vars;
-  Array iter_values;
+  Array block_iter_vars;
+  Array block_iter_values;
 };
 
-// /*! \brief Return the buffer region related with the buffer */
-// Optional GetBufferRegionFromBuffer(const Array& buffer_regions,
-//                                    const Buffer& buffer) {
-//   Optional res = NullOpt;
-//   for (const auto& region : buffer_regions) {
-//     if (region->buffer.same_as(buffer)) {
-//       ICHECK(!res.defined());
-//       res = region;
-//     }
-//   }
-//   return res;
-// }
+/*! \brief The schedule error raised when the accessed buffer region is not a single point for
+ * reindex_cache_read/write. */
+class NotSinglePointAccess : public ScheduleError {
+ public:
+  explicit NotSinglePointAccess(IRModule mod, Block block, BufferRegion cache_region,
+                                bool is_cache_read)
+      : mod_(std::move(mod)), block_(std::move(block)), cache_region_(cache_region) {
+    primitive_name_ = is_cache_read ? "reindex_cache_read" : "reindex_cache_write";
+  }
+
+  String FastErrorString() const final {
+    return "ScheduleError: The buffer region accessed inside the block is not a single point.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::stringstream s;
+    s << "The buffer region " << cache_region_
+      << " accessed inside block {0} is not a single point, which violates"
+      << " the prerequisite of the " << primitive_name_ << " primitive.";
+    return String(s.str());
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array LocationsOfInterest() const final { return {block_}; }
+
+ private:
+  IRModule mod_;
+  Block block_;
+  BufferRegion cache_region_;
+  String primitive_name_;
+};
 
 /*!
- * \brief Create a loop nest that represents reverse cache copy (reverse_read / reverse_write) from
- * read buffer to write buffer. \param cache_region The cached copy region. \param info The cache
- * stage information, which will be updated in the function. \param storage_scope The storage scope
- * of the cached buffer (only used in naming here) \returns A block indicating the body of the loop
- * nesting.
+ * \brief Create a loop nest that represents reindex cache copy (reindex_cache_read / + * reindex_cache_write) from read buffer to write buffer. + * \param cache_region The cached copy region. + * \param info The cache stage information, which will be updated in the function. + * \param storage_scope The storage scope of the cached buffer (only used in naming here) + * \returns A block indicating the body of the loop nesting. */ -Block MakeReverseCacheStage(const BufferRegion& cache_region, ReverseCacheTouchedInfo* touched_info, - CacheStageInfo* info, const String& storage_scope) { - /* Check whether cache region is a single point. */ - bool single_point = true; - for (const Range& range : cache_region->region) { - const auto* ext_int = range->extent.as(); - if (!ext_int || ext_int->value != 1) { - single_point = false; - } - } - CHECK(single_point) << "ReverseCacheStage is only valid when cache region is a single point."; +template +Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageInfo* info, + const String& storage_scope) { // loop variables std::vector loop_vars; // block variables @@ -144,15 +156,15 @@ Block MakeReverseCacheStage(const BufferRegion& cache_region, ReverseCacheTouche std::vector iter_values; // Create loop vars and block vars' binding_value Map var_map; - for (size_t i = 0; i < touched_info->loop_vars.size(); ++i) { - Var original_var = touched_info->loop_vars[i]; + for (size_t i = 0; i < info->loop_vars.size(); ++i) { + Var original_var = info->loop_vars[i]; Var loop_var("ax" + std::to_string(i), original_var.dtype()); var_map.Set(original_var, loop_var); loop_vars.push_back(loop_var); } - for (size_t i = 0; i < touched_info->block_vars.size(); ++i) { - IterVar original_block_var = touched_info->block_vars[i]; - PrimExpr original_iter_value = touched_info->iter_values[i]; + for (size_t i = 0; i < info->block_iter_vars.size(); ++i) { + IterVar original_block_var = info->block_iter_vars[i]; + PrimExpr original_iter_value = info->block_iter_values[i]; IterVar block_var = IterVar( /*dom=*/original_block_var->dom, /*var=*/Var("v" + std::to_string(i), original_block_var->var.dtype()), @@ -166,23 +178,17 @@ Block MakeReverseCacheStage(const BufferRegion& cache_region, ReverseCacheTouche Region read_access_region, write_access_region; Array read_access_indices, write_access_indices; // Compute read/write region and read/write access indices. + Array& old_indices = (is_cache_read) ? read_access_indices : write_access_indices; + Region& old_region = (is_cache_read) ? 
read_access_region : write_access_region; for (const Range& range : cache_region->region) { - if (touched_info->read) { - read_access_indices.push_back(Substitute(range->min, var_map)); - read_access_region.push_back(Range::FromMinExtent(read_access_indices.back(), Integer(1))); - } else { - write_access_indices.push_back(Substitute(range->min, var_map)); - write_access_region.push_back(Range::FromMinExtent(write_access_indices.back(), Integer(1))); - } + old_indices.push_back(Substitute(range->min, var_map)); + old_region.push_back(Range::FromMinExtent(old_indices.back(), Integer(1))); } - for (const IterVar& block_var : block_vars) { - if (touched_info->read) { - write_access_indices.push_back(block_var->var); - write_access_region.push_back(Range::FromMinExtent(write_access_indices.back(), Integer(1))); - } else { - read_access_indices.push_back(block_var->var); - read_access_region.push_back(Range::FromMinExtent(read_access_indices.back(), Integer(1))); - } + Array& new_indices = (is_cache_read) ? write_access_indices : read_access_indices; + Region& new_region = (is_cache_read) ? write_access_region : read_access_region; + for (const PrimExpr& idx : info->indices) { + new_indices.push_back(Substitute((idx), var_map)); + new_region.push_back(Range::FromMinExtent(new_indices.back(), Integer(1))); } // Create New Block @@ -205,8 +211,8 @@ Block MakeReverseCacheStage(const BufferRegion& cache_region, ReverseCacheTouche // Create surrounding loops for (size_t i = loop_vars.size(); i >= 1; --i) { body = For(/*loop_var=*/loop_vars[i - 1], - /*min=*/touched_info->loop_ranges[i - 1]->min, - /*extent=*/touched_info->loop_ranges[i - 1]->extent, + /*min=*/info->loop_ranges[i - 1]->min, + /*extent=*/info->loop_ranges[i - 1]->extent, /*kind=*/ForKind::kSerial, /*body=*/body); } @@ -498,9 +504,11 @@ class CacheLocDetector : public StmtVisitor { public: /*! * \brief Detect the insertion position of the cache stage, and write the position into the - * CacheStageInfo \param self The state of the schedule \param block_sref The sref of the unique - * writer block of the buffer being applied cache_read or cache_write \param scope_sref The sref - * of the scope block of the cached block \param info The cache stage info. + * CacheStageInfo + * \param self The state of the schedule + * \param block_sref The sref of the unique writer block of the buffer being applied cache_read or cache_write + * \param scope_sref The sref of the scope block of the cached block + * \param info The cache stage info. */ template static void Detect(const ScheduleState& self, const StmtSRef& block_sref, @@ -553,8 +561,9 @@ class CacheLocDetector : public StmtVisitor { * \brief Constructor * \param self The state of the schedule * \param block_sref The sref of the unique writer block of the buffer being applied cache_read or - * cache_write \param scope_sref The sref of the scope block of the cached block \param - * related_blocks Producer blocks for cache_write, or consumer blocks for cache_read + * cache_write + * \param scope_sref The sref of the scope block of the cached block + * \param related_blocks Producer blocks for cache_write, or consumer blocks for cache_read */ CacheLocDetector(const ScheduleState self, const StmtSRef& block_sref, const StmtSRef& scope_sref, const std::vector& related_blocks) @@ -645,9 +654,11 @@ class CacheInplaceLocDetector : public StmtVisitor { public: /*! 
* \brief Detect the insertion position of the cache stage, and write the position into the - * CacheStageInfo \param self The state of the schedule \param block_sref The sref of the unique - * block of the buffer being applied cache_inplace \param scope_sref The sref - * of the scope block of the cached block \param info The cache stage info. + * CacheStageInfo + * \param self The state of the schedule + * \param block_sref The sref of the unique block of the buffer being applied cache_inplace + * \param scope_sref The sref of the scope block of the cached block + * \param info The cache stage info. */ static void Detect(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_sref, CacheStageInfo* info) { @@ -726,137 +737,7 @@ class CacheInplaceLocDetector : public StmtVisitor { int loc_pos_{-1}; }; -/*! \brief Mutator for ReverseCacheRead. */ -class ReverseCacheReadRewriter : public StmtExprMutator { - public: - /*! - * \brief Rewrite the AST and add a cache_read stage with the information provided. - * \param scope_sref The parent scope of this mutation. - * \param info The cache stage information. - * \param touched_info The reverse cache touched information. - * \return The new AST rooting at the original parent scope. - */ - static Stmt Rewrite(const StmtSRef& scope_sref, CacheStageInfo* info, - ReverseCacheTouchedInfo* touched_info) { - ReverseCacheReadRewriter rewriter(scope_sref, info, touched_info); - return rewriter(GetRef(scope_sref->stmt)); - } - - private: - explicit ReverseCacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info, - ReverseCacheTouchedInfo* touched_info) - : scope_sref_(scope_sref), info_(info) { - for (const IterVar& iter_var : touched_info->block_vars) { - new_indices_.push_back(iter_var->var); - } - } - - Stmt VisitStmt_(const ForNode* loop) final { - Stmt stmt = StmtMutator::VisitStmt_(loop); - // Check the insertion point - if (loop == info_->loc_sref->stmt) { - // Insert cache stage into the loop if it is the right place - ObjectPtr n = make_object(*stmt.as()); - n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); - stmt = Stmt(n); - } - return stmt; - } - - Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { - Stmt stmt = StmtMutator::VisitStmt_(block_realize); - return stmt; - } - - Stmt VisitStmt_(const BlockNode* block) final { - Block old_stmt = GetRef(block); - // Check if this block is one of the specified consumers. - // If no consumer blocks are specified, all blocks should be considered consumers. - bool is_consumer = info_->consumer_blocks.empty(); - // Otherwise check if this is one of the specified blocks. - for (StmtSRef consumer_sref : info_->consumer_blocks) { - const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); - Block consumer_block = GetRef(consumer_node); - if (old_stmt.same_as(consumer_block)) { - is_consumer = true; - } - } - // Keep track of this blocks status. We'll use this when rewriting loads. 
- current_block_consumes = is_consumer; - if (block != scope_sref_->stmt && - GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) { - return std::move(old_stmt); - } - // Mutate the body - Block stmt = Downcast(StmtMutator::VisitStmt_(block)); - // Check the insertion point - if (block == info_->loc_sref->stmt) { - // Insert cache stage into the block if it is the right place - ObjectPtr n = make_object(*stmt.as()); - n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); - stmt = Block(n); - } - // Check if it is the block corresponding to the parent scope - if (block == scope_sref_->stmt) { - // If so, put buffer allocation on the parent scope - ObjectPtr n = make_object(*stmt.as()); - n->alloc_buffers.push_back(info_->alloc.value()); - stmt = Block(n); - } else { - // Otherwise, update read regions and match_buffers - // Only make this change if the block is one of the specified consumers. - if (is_consumer) { - Array reads; - for (const BufferRegion& buf_region : block->reads) { - if (buf_region->buffer.same_as(info_->read_buffer)) { - Array region; - for (const PrimExpr index : new_indices_) { - region.push_back(Range::FromMinExtent(index, Integer(1))); - } - reads.push_back(BufferRegion(info_->write_buffer, region)); - } else { - reads.push_back(buf_region); - } - } - - // NOTE(Zihao): do not process match buffers for now. - if (!reads.same_as(block->reads)) { - ObjectPtr n = make_object(*stmt.as()); - n->reads = std::move(reads); - stmt = Block(n); - } - } - } - info_->block_reuse.Set(old_stmt, stmt); - return std::move(stmt); - } - - PrimExpr VisitExpr_(const VarNode* op) final { - if (op == info_->read_buffer->data.get()) { - return info_->write_buffer->data; - } - return GetRef(op); - } - - PrimExpr VisitExpr_(const BufferLoadNode* load) final { - if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { - ObjectPtr n = make_object(*load); - n->buffer = info_->write_buffer; - n->indices = new_indices_; - return PrimExpr(n); - } - return ExprMutator::VisitExpr_(load); - } - - /*! \brief The parent scope of the insertion */ - const StmtSRef& scope_sref_; - /*! \brief The info for inserting cache stage */ - CacheStageInfo* info_; - /*! \brief The indices to use for new buffer. */ - Array new_indices_; - /*! \brief Whether the most recently visited block is a specified consumer. */ - bool current_block_consumes; -}; +class ReindexCacheReadRewriter; /*! \brief Mutator for CacheRead. */ class CacheReadRewriter : public StmtExprMutator { @@ -874,7 +755,14 @@ class CacheReadRewriter : public StmtExprMutator { private: explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info) - : scope_sref_(scope_sref), info_(info) {} + : scope_sref_(scope_sref), info_(info) { + update_access_regions = [&](Array regions) { + return ReplaceBuffer(std::move(regions), info_->read_buffer, info_->write_buffer); + }; + update_match_buffers = [&](Array match_buffers) { + return ReplaceBuffer(std::move(match_buffers), info_->read_buffer, info_->write_buffer); + }; + } Stmt VisitStmt_(const ForNode* loop) final { Stmt stmt = StmtMutator::VisitStmt_(loop); @@ -888,7 +776,7 @@ class CacheReadRewriter : public StmtExprMutator { return stmt; } - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const BlockNode* block) override { Block old_stmt = GetRef(block); // Check if this block is one of the specified consumers. // If no consumer blocks are specified, all blocks should be considered consumers. 
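     // (ReindexCacheReadRewriter reuses this traversal, swapping in reindexed
     // regions through the update callbacks and its own BufferLoad visitor.)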
@@ -930,10 +818,8 @@ class CacheReadRewriter : public StmtExprMutator { // Otherwise, update read regions and match_buffers // Only make this change if the block is one of the specified consumers. if (is_consumer) { - Array reads = - ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer); - Array match_buffers = - ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer); + Array reads = update_access_regions(block->reads); + Array match_buffers = update_match_buffers(block->match_buffers); if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { ObjectPtr n = make_object(*stmt.as()); n->reads = std::move(reads); @@ -946,7 +832,7 @@ class CacheReadRewriter : public StmtExprMutator { return std::move(stmt); } - PrimExpr VisitExpr_(const BufferLoadNode* load) final { + PrimExpr VisitExpr_(const BufferLoadNode* load) override { if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { ObjectPtr n = make_object(*load); n->buffer = info_->write_buffer; @@ -973,184 +859,82 @@ class CacheReadRewriter : public StmtExprMutator { CacheStageInfo* info_; /*! \brief Whether the most recently visited block is a specified consumer. */ bool current_block_consumes; + /*! \brief function to update read/write region of block being cache read.*/ + std::function(Array)> update_access_regions; + /*! \brief function to update match buffers of block being cache read.*/ + std::function(Array)> update_match_buffers; + + friend ReindexCacheReadRewriter; }; -/*! \brief Mutator for ReverseCacheWrite. */ -class ReverseCacheWriteRewriter : public StmtExprMutator { +/*! \brief Mutator for ReindexCacheRead. */ +class ReindexCacheReadRewriter : public CacheReadRewriter { public: /*! - * \brief Rewrite the AST and add a cache_write stage with the information provided. + * \brief Rewrite the AST and add a cache_read stage with the information provided. * \param scope_sref The parent scope of this mutation. - * \param writer_block_sref The only writer block in the scope. * \param info The cache stage information. - * \param touched_info The reverse cache touched information. * \return The new AST rooting at the original parent scope. 
*/ - static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, - CacheStageInfo* info, ReverseCacheTouchedInfo* touched_info) { - ReverseCacheWriteRewriter rewriter(scope_sref, writer_block_sref, info, touched_info); + static Stmt Rewrite(const StmtSRef& scope_sref, ReindexCacheStageInfo* info) { + ReindexCacheReadRewriter rewriter(scope_sref, info); return rewriter(GetRef(scope_sref->stmt)); } private: - explicit ReverseCacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, - CacheStageInfo* info, ReverseCacheTouchedInfo* touched_info) - : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) { - for (const IterVar& iter_var : touched_info->block_vars) { - new_indices_.push_back(iter_var->var); - } - } - - Stmt VisitStmt_(const ForNode* loop) final { - Stmt stmt = StmtMutator::VisitStmt_(loop); - // Check the insertion point - if (loop == info_->loc_sref->stmt) { - // Insert cache stage into the loop if it is the right place - ObjectPtr n = make_object(*stmt.as()); - n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); - stmt = Stmt(n); - } - return stmt; - } - - Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { - Stmt stmt = StmtMutator::VisitStmt_(block_realize); - return stmt; - } - - Stmt VisitStmt_(const BlockNode* block) final { - Block old_stmt = GetRef(block); - - // Check if this block is one of the specified cache consumers. - // update the read buffer to the cache. - for (StmtSRef consumer_sref : info_->consumer_blocks) { - const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); - Block consumer_block = GetRef(consumer_node); - if (old_stmt.same_as(consumer_block)) { - Array reads = - ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer); - Array match_buffers = - ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer); - if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { - auto n = CopyOnWrite(block); - n->reads = std::move(reads); - n->match_buffers = std::move(match_buffers); - n->body = VisitStmt(block->body); - Block new_consumer = Block(n); - info_->block_reuse.Set(old_stmt, new_consumer); - return std::move(new_consumer); - } - return std::move(old_stmt); - } - } - - // We only mutate the block which generates info->write_buffer - if (block != writer_block_sref_->stmt && block != scope_sref_->stmt && !under_writer_block_) { - return std::move(old_stmt); - } - - // Mutate the body - bool under_scope = under_writer_block_ || block == writer_block_sref_->stmt; - std::swap(under_scope, under_writer_block_); - Block stmt = Downcast(StmtMutator::VisitStmt_(block)); - std::swap(under_scope, under_writer_block_); - - // Find the insertion point - if (block == info_->loc_sref->stmt) { - ObjectPtr n = make_object(*stmt.as()); - n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); - stmt = Block(n); - } - // Put buffer allocation on the parent scope - if (block == scope_sref_->stmt) { - ObjectPtr n = make_object(*stmt.as()); - n->alloc_buffers.push_back(info_->alloc.value()); - stmt = Block(n); - } else { - // Since cache_write changes the block, we need to update the buffer it writes - Array reads, writes; - // New reads info. 
- for (const BufferRegion& buf_region : block->reads) { - if (buf_region->buffer.same_as(info_->write_buffer)) { + explicit ReindexCacheReadRewriter(const StmtSRef& scope_sref, ReindexCacheStageInfo* info) + : CacheReadRewriter(scope_sref, info) { + new_indices_ = info->indices; + update_access_regions = [&](Array reads) { + Array new_reads; + for (const BufferRegion& buf_region : reads) { + if (buf_region->buffer.same_as(info_->read_buffer)) { Array region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } - reads.push_back(BufferRegion(info_->read_buffer, region)); + new_reads.push_back(BufferRegion(info_->write_buffer, region)); } else { - reads.push_back(buf_region); + new_reads.push_back(buf_region); } } - // New writes info, same as above. - for (const BufferRegion& buf_region : block->writes) { - if (buf_region->buffer.same_as(info_->write_buffer)) { + return new_reads; + }; + update_match_buffers = [&](const Array match_buffers) { + Array new_match_buffers; + for (const MatchBufferRegion& match_buffer_region : match_buffers) { + BufferRegion source = match_buffer_region->source; + if (source->buffer.same_as(info_->read_buffer)) { Array region; for (const PrimExpr index : new_indices_) { region.push_back(Range::FromMinExtent(index, Integer(1))); } - writes.push_back(BufferRegion(info_->read_buffer, region)); + new_match_buffers.push_back(MatchBufferRegion(match_buffer_region->buffer, + BufferRegion(info_->write_buffer, region))); } else { - writes.push_back(buf_region); + new_match_buffers.push_back(match_buffer_region); } } - - // NOTE(Zihao): do not process match buffers for now. - if (!writes.same_as(block->writes) || !reads.same_as(block->reads)) { - ObjectPtr n = make_object(*stmt.as()); - n->writes = std::move(writes); - n->reads = std::move(reads); - stmt = Block(n); - } - } - // Remove atomic flag - ObjectPtr n = make_object(*stmt.as()); - n->annotations.erase("atomic"); - stmt = Block(n); - info_->block_reuse.Set(old_stmt, stmt); - return std::move(stmt); - } - - PrimExpr VisitExpr_(const VarNode* op) final { - if (op == info_->write_buffer->data.get()) { - return info_->read_buffer->data; - } - return GetRef(op); - } - - Stmt VisitStmt_(const BufferStoreNode* store) final { - BufferStore stmt = Downcast(StmtMutator::VisitStmt_(store)); - if (stmt->buffer.same_as(info_->write_buffer)) { - auto n = CopyOnWrite(stmt.get()); - n->buffer = info_->read_buffer; - n->indices = new_indices_; - return Stmt(n); - } else { - return std::move(stmt); - } + return new_match_buffers; + }; } PrimExpr VisitExpr_(const BufferLoadNode* load) final { - if (load->buffer.same_as(info_->write_buffer)) { + if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { ObjectPtr n = make_object(*load); - n->buffer = info_->read_buffer; + n->buffer = info_->write_buffer; n->indices = new_indices_; return PrimExpr(n); } return ExprMutator::VisitExpr_(load); } - /*! \brief The parent scope of the insertion. */ - const StmtSRef& scope_sref_; - /*! \brief The parent scope of the insertion. */ - const StmtSRef& writer_block_sref_; - /*! \brief The info for inserting cache stage. */ - CacheStageInfo* info_; /*! \brief The indices to use for new buffer. */ Array new_indices_; - /*! \brief Whether the current node is under the given block. */ - bool under_writer_block_{false}; }; +class ReindexCacheWriteRewriter; + /*! 
\brief Mutator for CacheWrite */ class CacheWriteRewriter : public StmtExprMutator { public: @@ -1170,7 +954,14 @@ class CacheWriteRewriter : public StmtExprMutator { private: explicit CacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, CacheStageInfo* info) - : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) {} + : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) { + update_access_regions = [&](Array regions) { + return ReplaceBuffer(regions, info_->write_buffer, info_->read_buffer); + }; + update_match_buffers = [&](Array match_buffers) { + return ReplaceBuffer(match_buffers, info_->write_buffer, info_->read_buffer); + }; + } Stmt VisitStmt_(const ForNode* loop) final { Stmt stmt = StmtMutator::VisitStmt_(loop); @@ -1184,7 +975,7 @@ class CacheWriteRewriter : public StmtExprMutator { return stmt; } - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const BlockNode* block) override { Block old_stmt = GetRef(block); // Check if this block is one of the specified cache consumers. @@ -1237,10 +1028,9 @@ class CacheWriteRewriter : public StmtExprMutator { } } else { // Since cache_write changes the block, we need to update the buffer it writes - auto writes = ReplaceBuffer(block->writes, info_->write_buffer, info_->read_buffer); - auto reads = ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer); - auto match_buffers = - ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer); + auto writes = update_access_regions(block->writes); + auto reads = update_access_regions(block->reads); + auto match_buffers = update_match_buffers(block->match_buffers); if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { ObjectPtr n = make_object(*stmt.as()); @@ -1254,7 +1044,7 @@ class CacheWriteRewriter : public StmtExprMutator { return std::move(stmt); } - Stmt VisitStmt_(const BufferStoreNode* store) final { + Stmt VisitStmt_(const BufferStoreNode* store) override { BufferStore stmt = Downcast(StmtMutator::VisitStmt_(store)); if (stmt->buffer.same_as(info_->write_buffer)) { auto n = CopyOnWrite(stmt.get()); @@ -1265,7 +1055,7 @@ class CacheWriteRewriter : public StmtExprMutator { } } - PrimExpr VisitExpr_(const BufferLoadNode* load) final { + PrimExpr VisitExpr_(const BufferLoadNode* load) override { if (load->buffer.same_as(info_->write_buffer)) { ObjectPtr n = make_object(*load); n->buffer = info_->read_buffer; @@ -1298,6 +1088,93 @@ class CacheWriteRewriter : public StmtExprMutator { CacheStageInfo* info_; /*! \brief Whether the current node is under the given block. */ bool under_writer_block_{false}; + /*! \brief function to update read/write region of block being cache write.*/ + std::function(Array)> update_access_regions; + /*! \brief function to update match buffers of block being cache write.*/ + std::function(Array)> update_match_buffers; + + friend ReindexCacheWriteRewriter; +}; + +/*! \brief Mutator for ReindexCacheWrite. */ +class ReindexCacheWriteRewriter : public CacheWriteRewriter { + public: + /*! + * \brief Rewrite the AST and add a cache_write stage with the information provided. + * \param scope_sref The parent scope of this mutation. + * \param writer_block_sref The only writer block in the scope. + * \param info The cache stage information. + * \return The new AST rooting at the original parent scope. 
+ */
+  static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref,
+                      ReindexCacheStageInfo* info) {
+    ReindexCacheWriteRewriter rewriter(scope_sref, writer_block_sref, info);
+    return rewriter(GetRef(scope_sref->stmt));
+  }
+
+ private:
+  explicit ReindexCacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref,
+                                     ReindexCacheStageInfo* info)
+      : CacheWriteRewriter(scope_sref, writer_block_sref, info) {
+    new_indices_ = info->indices;
+    update_access_regions = [&](Array reads) {
+      Array new_reads;
+      for (const BufferRegion& buf_region : reads) {
+        if (buf_region->buffer.same_as(info_->write_buffer)) {
+          Array region;
+          for (const PrimExpr index : new_indices_) {
+            region.push_back(Range::FromMinExtent(index, Integer(1)));
+          }
+          new_reads.push_back(BufferRegion(info_->read_buffer, region));
+        } else {
+          new_reads.push_back(buf_region);
+        }
+      }
+      return new_reads;
+    };
+    update_match_buffers = [&](const Array match_buffers) {
+      Array new_match_buffers;
+      for (const MatchBufferRegion& match_buffer_region : match_buffers) {
+        BufferRegion source = match_buffer_region->source;
+        if (source->buffer.same_as(info_->write_buffer)) {
+          Array region;
+          for (const PrimExpr index : new_indices_) {
+            region.push_back(Range::FromMinExtent(index, Integer(1)));
+          }
+          new_match_buffers.push_back(MatchBufferRegion(match_buffer_region->buffer,
+                                                        BufferRegion(info_->read_buffer, region)));
+        } else {
+          new_match_buffers.push_back(match_buffer_region);
+        }
+      }
+      return new_match_buffers;
+    };
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) final {
+    BufferStore stmt = Downcast(StmtMutator::VisitStmt_(store));
+    if (stmt->buffer.same_as(info_->write_buffer)) {
+      auto n = CopyOnWrite(stmt.get());
+      n->buffer = info_->read_buffer;
+      n->indices = new_indices_;
+      return Stmt(n);
+    } else {
+      return std::move(stmt);
+    }
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    if (load->buffer.same_as(info_->write_buffer)) {
+      ObjectPtr n = make_object(*load);
+      n->buffer = info_->read_buffer;
+      n->indices = new_indices_;
+      return PrimExpr(n);
+    }
+    return ExprMutator::VisitExpr_(load);
+  }
+
+  /*! \brief The indices to use for new buffer. */
+  Array new_indices_;
+};
 
 /*!
@@ -1725,6 +1602,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   return result_block_sref;
 }
 
+/*! \brief A visitor that collects variables appearing in expressions, stored in the
+ * touched field. */
 class VarCollector : public ExprVisitor {
  public:
   VarCollector() {}
@@ -1744,8 +1622,148 @@ Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& t
   return {result.rbegin(), result.rend()};
 }
 
-StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
-                          const String& storage_scope, Array dim_order,
+/*! \brief The schedule error raised when the set of block iter vars appearing in the old
+ * buffer indices does not match the set appearing in the new cache buffer indices.
+ */
+class ReindexCacheReadWriteNotMatchError : public ScheduleError {
+ public:
+  ReindexCacheReadWriteNotMatchError(IRModule mod, Block block, Var var,
+                                     Array old_indices, Array new_indices,
+                                     bool is_cache_read, bool appears_in_old)
+      : mod_(std::move(mod)), block_(std::move(block)), var_(std::move(var)) {
+    primitive_name_ = is_cache_read ?
"reindex_cache_read" : "reindex_cache_write"; + if (appears_in_old) { + appears_indices_ = std::move(old_indices); + other_indices_ = std::move(new_indices); + } else { + appears_indices_ = std::move(new_indices); + other_indices_ = std::move(old_indices); + } + } + String FastErrorString() const final { + return "ScheduleError: the block itervars appeared in lhs and rhs of reindex cache stage do " + "not match."; + } + + String DetailRenderTemplate() const final { + std::stringstream s; + s << "Error when applying " << primitive_name_ << " on block {0}, the block itervar " << var_ + << " appears in " << appears_indices_ << ", but not in " << other_indices_ << "."; + return String(s.str()); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + String primitive_name_; + Block block_; + Var var_; + Array appears_indices_; + Array other_indices_; +}; + +/*! + * \brief Update ReindexCacheStageInfo and create new cache buffer, used in + * both ReindexCacheRead and ReindexCacheWrite. + * \param info Pointer to ReindexCacheStageInfo + * \param mod The IRModule. + * \param block_sref The StmtSRef to the block we are working on. + * \param storage_scope The storage scope of cache buffer (e.g. "shared"/"local"). + * \param index_map The user defined indices. + * \param blok The block we are working on. + * \param realize The BlockRealize this block belongs to. + * \param old_buffer The buffer whose buffer access need to be rewriten. + * \param cache_region The old buffer access region. + */ +template +void CollectReindexCacheStageInfoAndCreateBuffer( + ReindexCacheStageInfo* info, const IRModule& mod, const StmtSRef& block_sref, + const String& storage_scope, const IndexMap& index_map, const Block& block, + const BlockRealize& realize, const Buffer& old_buffer, const BufferRegion& cache_region) { + Array block_iter_vars, block_shape; + for (const IterVar& iter_var : block->iter_vars) { + block_iter_vars.push_back(iter_var); + block_shape.push_back(iter_var->dom->extent); + } + Array new_indices = index_map->MapIndices(block_iter_vars); + Array new_shape = index_map->MapShape(block_shape); + info->indices = new_indices; + + // Step 5. 
Update CacheTouchedInfo + VarCollector collector_old; + Array old_indices; + for (const Range& range : cache_region->region) { + collector_old(range->min); + old_indices.push_back(range->min); + } + + arith::Analyzer analyzer; + + VarCollector collector_new; + for (const PrimExpr& idx : new_indices) { + collector_new(idx); + } + + VarCollector collector_iter_values; + for (size_t i = 0; i < block->iter_vars.size(); ++i) { + const IterVar& block_iter_var = block->iter_vars[i]; + const PrimExpr& block_iter_value = realize->iter_values[i]; + bool appears_in_new = collector_new.touched.count(block_iter_var->var.get()); + bool appears_in_old = collector_old.touched.count(block_iter_var->var.get()); + if (appears_in_new != appears_in_old) { + throw ReindexCacheReadWriteNotMatchError(mod, block, block_iter_var->var, old_indices, + new_indices, is_cache_read, appears_in_old); + } + if (appears_in_new) { + info->block_iter_vars.push_back(block_iter_var); + info->block_iter_values.push_back(block_iter_value); + collector_iter_values(block_iter_value); + } + } + + for (const StmtSRef& loop_sref : GetLoopsUnderScope(block_sref, info->loc_sref)) { + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); + if (collector_iter_values.touched.count(loop->loop_var.get())) { + info->loop_vars.push_back(loop->loop_var); + info->loop_ranges.push_back(Range::FromMinExtent(loop->min, loop->extent)); + } + } + + // Create new buffer + ObjectPtr new_buffer = make_object(*old_buffer.get()); + ObjectPtr new_var = make_object(*old_buffer->data.get()); + const auto* ptr_type = TVM_TYPE_AS(old_buffer->data->type_annotation, PointerTypeNode); + new_var->type_annotation = PointerType(ptr_type->element_type, storage_scope); + new_buffer->data = Var(new_var->name_hint + "_" + storage_scope, new_var->type_annotation); + new_buffer->name = old_buffer->name + "_" + storage_scope; + new_buffer->shape = new_shape; + + if (is_cache_read) { + info->write_buffer = Buffer(new_buffer); + info->alloc = info->write_buffer; + } else { + info->read_buffer = Buffer(new_buffer); + info->alloc = info->read_buffer; + } +} + +/*! \brief Check whether given cache_region is a single point access. */ +template +void CheckSinglePoint(ScheduleState self, const Block& block, const BufferRegion& cache_region) { + bool single_point = true; + for (const Range& range : cache_region->region) { + const auto* ext_int = range->extent.as(); + if (!ext_int || ext_int->value != 1) { + single_point = false; + } + } + if (!single_point) { + throw NotSinglePointAccess(self->mod, block, cache_region, is_cache_read); + } +} + +StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, + const String& storage_scope, const IndexMap& index_map, const Array consumer_blocks) { /*! * Check: @@ -1762,14 +1780,13 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re CheckStorageScope(self, storage_scope); // Step 1. Check index, getting the target buffer and the parent scope - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); - Buffer read_buffer = - GetNthAccessBuffer(self, GetRef(block), read_buffer_index, BufferIndexType::kRead); + Block block = GetRef(TVM_SREF_TO_BLOCK(block_sref)); + BlockRealize realize = GetBlockRealize(self, block_sref); + Buffer read_buffer = GetNthAccessBuffer(self, block, read_buffer_index, BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); - const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); // Step 2. 
Create CacheStageInfo - CacheStageInfo info; + ReindexCacheStageInfo info; info.read_buffer = read_buffer; // info.consumer_blocks indicates which buffers should consume the cache. @@ -1788,7 +1805,6 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { // Case 1. The buffer is written inside the block. StmtSRef write_block_sref = _write_block_sref.value(); - const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref); // Find the producing region StmtSRef parent_sref = GetRef(write_block_sref->parent); // Detect insert position @@ -1799,71 +1815,20 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re info.loc_pos = 0; } - // Step 4. Create CacheTouchedInfo - ReverseCacheTouchedInfo touched_info; + // Step 4. Check whether cache region is a single point. + CheckSinglePoint(self, block, cache_region); - // Step 5. Update CacheTouchedInfo - touched_info.read = true; - VarCollector collector; - Array new_shape; - for (const Range& range : cache_region->region) { - collector(range->min); - } - BlockRealize realize = GetBlockRealize(self, block_sref); - std::unordered_set dim_order_set; - for (const Integer& idx : dim_order) { - dim_order_set.insert(idx->value); - } - for (size_t idx = 0; idx < block->iter_vars.size(); ++idx) { - const IterVar& block_var = block->iter_vars[idx]; - if (collector.touched.count(block_var->var.get())) { - if (dim_order_set.empty()) { - // no user defined dim order. - dim_order.push_back(idx); - } else { - // user provide dim order, check whether it's valid. - CHECK(dim_order_set.count(idx)) - << "Block iter_var " << block_var - << " used in the block, but doesn't appear in user-specified dim order array."; - } - } - } - - for (size_t i = 0; i < dim_order.size(); ++i) { - int idx = dim_order[i]->value; - const IterVar& block_var = block->iter_vars[idx]; - touched_info.block_vars.push_back(block_var); - touched_info.iter_values.push_back(realize->iter_values[idx]); - new_shape.push_back(block_var->dom->min + block_var->dom->extent); - collector(touched_info.iter_values.back()); - } - - for (const StmtSRef& loop_sref : GetLoopsUnderScope(block_sref, info.loc_sref)) { - const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - if (collector.touched.count(loop->loop_var.get())) { - touched_info.loop_vars.push_back(loop->loop_var); - touched_info.loop_ranges.push_back(Range::FromMinExtent(loop->min, loop->extent)); - } - } - - // Create write buffer. - ObjectPtr new_buffer = make_object(*read_buffer.get()); - ObjectPtr new_var = make_object(*read_buffer->data.get()); - const auto* ptr_type = TVM_TYPE_AS(read_buffer->data->type_annotation, PointerTypeNode); - new_var->type_annotation = PointerType(ptr_type->element_type, storage_scope); - new_buffer->data = Var(new_var->name_hint + "_" + storage_scope, new_var->type_annotation); - new_buffer->name = read_buffer->name + "_" + storage_scope; - new_buffer->shape = new_shape; - - info.write_buffer = Buffer(new_buffer); - info.alloc = info.write_buffer; + // Step 5. Collect ReindexCacheStageInfo and create new buffer. + CollectReindexCacheStageInfoAndCreateBuffer( + &info, self->mod, block_sref, storage_scope, index_map, block, realize, read_buffer, + cache_region); // Step 6. Making new cache stage block and rewrite readers. 
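   // (The stage built below copies the original buffer at its old indices into
   // the cache at the user-mapped indices; the rewriter then redirects consumer
   // reads to the reindexed cache.)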
- Block cache_read_stage = MakeReverseCacheStage(/*cache_region=*/cache_region, - /*touched_info=*/&touched_info, /*info=*/&info, - /*storage_scope=*/storage_scope); - Stmt new_scope = ReverseCacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info, - /*touched_info=*/&touched_info); + Block cache_read_stage = + MakeReindexCacheStage(/*cache_region=*/cache_region, + /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = ReindexCacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); // Step 7. Replacing and updating flags. self->Replace(scope_sref, new_scope, info.block_reuse); @@ -1875,8 +1840,8 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re return result_block_sref; } -StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, - const String& storage_scope, Array dim_order, +StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, + const String& storage_scope, const IndexMap& index_map, const Array consumer_blocks) { /*! * Check: @@ -1893,13 +1858,14 @@ StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w CheckStorageScope(self, storage_scope); // Step 1. Checking index, getting the target buffer and the parent scope - const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + Block block = GetRef(TVM_SREF_TO_BLOCK(block_sref)); + BlockRealize realize = GetBlockRealize(self, block_sref); Buffer write_buffer = - GetNthAccessBuffer(self, GetRef(block), write_buffer_index, BufferIndexType::kWrite); + GetNthAccessBuffer(self, block, write_buffer_index, BufferIndexType::kWrite); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); // Step 2. Creating CacheStageInfo - CacheStageInfo info; + ReindexCacheStageInfo info; info.write_buffer = write_buffer; // info.consumer_blocks indicates which buffers should consume the cache. @@ -1921,74 +1887,23 @@ StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w CacheLocDetector::Detect(self, block_sref, scope_sref, &info); BufferRegion cache_region = maybe_region.value(); - // Step 5. Create CacheTouchedInfo - ReverseCacheTouchedInfo touched_info; - - // Step 6. Update CacheTouchedInfo - touched_info.read = false; - VarCollector collector; - Array new_shape; - for (const Range& range : cache_region->region) { - collector(range->min); - } - BlockRealize realize = GetBlockRealize(self, block_sref); - std::unordered_set dim_order_set; - for (const Integer& idx : dim_order) { - dim_order_set.insert(idx->value); - } - for (size_t idx = 0; idx < block->iter_vars.size(); ++idx) { - const IterVar& block_var = block->iter_vars[idx]; - if (collector.touched.count(block_var->var.get())) { - if (dim_order_set.empty()) { - // no user defined dim order. - dim_order.push_back(idx); - } else { - // user provide dim order, check whether it's valid. 
- CHECK(dim_order_set.count(idx)) - << "Block iter_var " << block_var - << " used in the block, but doesn't appear in user-specified dim order array."; - } - } - } + CollectReindexCacheStageInfoAndCreateBuffer( + &info, self->mod, block_sref, storage_scope, index_map, block, realize, write_buffer, + cache_region); - for (size_t i = 0; i < dim_order.size(); ++i) { - int idx = dim_order[i]->value; - const IterVar& block_var = block->iter_vars[idx]; - touched_info.block_vars.push_back(block_var); - touched_info.iter_values.push_back(realize->iter_values[idx]); - new_shape.push_back(block_var->dom->min + block_var->dom->extent); - collector(touched_info.iter_values.back()); - } + // Step 5. Check whether cache region is a single point. + CheckSinglePoint(self, block, cache_region); - for (const StmtSRef& loop_sref : GetLoopsUnderScope(block_sref, info.loc_sref)) { - const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); - if (collector.touched.count(loop->loop_var.get())) { - touched_info.loop_vars.push_back(loop->loop_var); - touched_info.loop_ranges.push_back(Range::FromMinExtent(loop->min, loop->extent)); - } - } - - // Create write buffer. - ObjectPtr new_buffer = make_object(*write_buffer.get()); - ObjectPtr new_var = make_object(*write_buffer->data.get()); - const auto* ptr_type = TVM_TYPE_AS(write_buffer->data->type_annotation, PointerTypeNode); - new_var->type_annotation = PointerType(ptr_type->element_type, storage_scope); - new_buffer->data = Var(new_var->name_hint + "_" + storage_scope, new_var->type_annotation); - new_buffer->name = write_buffer->name + "_" + storage_scope; - new_buffer->shape = new_shape; - - info.read_buffer = Buffer(new_buffer); - info.alloc = info.read_buffer; - - // Step 7. Making new cache stage block and rewrite readers. - Block cache_write_stage = MakeReverseCacheStage(/*cache_region=*/cache_region, - /*touched_info=*/&touched_info, /*info=*/&info, - /*storage_scope=*/storage_scope); - Stmt new_scope = ReverseCacheWriteRewriter::Rewrite( + // Step 6. Making new cache stage block and rewrite readers. + Block cache_write_stage = + MakeReindexCacheStage(/*cache_region=*/cache_region, + /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = ReindexCacheWriteRewriter::Rewrite( /*scope_sref=*/scope_sref, - /*writer_block_sref=*/block_sref, /*info=*/&info, /*touched_info=*/&touched_info); + /*writer_block_sref=*/block_sref, /*info=*/&info); - // Step 8. Replacing and updating flags. + // Step 7. Replacing and updating flags. 
self->Replace(scope_sref, new_scope, info.block_reuse); StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get()); BlockInfo& block_info = self->block_info[result_block_sref]; @@ -2307,8 +2222,8 @@ struct ReIndexTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; -struct ReverseCacheReadTraits : public UnpackedInstTraits { - static constexpr const char* kName = "ReverseCacheRead"; +struct ReindexCacheReadTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReindexCacheRead"; static constexpr bool kIsPure = false; private: @@ -2317,22 +2232,20 @@ struct ReverseCacheReadTraits : public UnpackedInstTraits dim_order, + String storage_scope, IndexMap index_map, Array consumer_blocks) { - return sch->ReverseCacheRead(block, read_buffer_index->value, storage_scope, dim_order, + return sch->ReindexCacheRead(block, read_buffer_index->value, storage_scope, index_map, consumer_blocks); } static String UnpackedAsPython(Array outputs, String block, Integer read_buffer_index, - String storage_scope, Array dim_order, + String storage_scope, IndexMap index_map, Array consumer_blocks) { - PythonAPICall py("reverse_cache_read"); + PythonAPICall py("reindex_cache_read"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); py.Input("storage_scope", storage_scope); - if (!dim_order.empty()) { - py.Input("dim_order", dim_order); - } + py.Input("index_map", index_map->ToPythonString()); if (!consumer_blocks.empty()) { py.Input("consumer_blocks", consumer_blocks); } @@ -2344,8 +2257,8 @@ struct ReverseCacheReadTraits : public UnpackedInstTraits { - static constexpr const char* kName = "ReverseCacheWrite"; +struct ReindexCacheWriteTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReindexCacheWrite"; static constexpr bool kIsPure = false; private: @@ -2354,23 +2267,21 @@ struct ReverseCacheWriteTraits : public UnpackedInstTraits dim_order, + String storage_scope, IndexMap index_map, Array consumer_blocks) { - return sch->ReverseCacheWrite(block, write_buffer_index->value, storage_scope, dim_order, + return sch->ReindexCacheWrite(block, write_buffer_index->value, storage_scope, index_map, consumer_blocks); } static String UnpackedAsPython(Array outputs, String block, Integer write_buffer_index, - String storage_scope, Array dim_order, + String storage_scope, IndexMap index_map, Array consumer_blocks) { - PythonAPICall py("reverse_cache_write"); + PythonAPICall py("reindex_cache_write"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index->value); py.Input("storage_scope", storage_scope); - if (!dim_order.empty()) { - py.Input("dim_order", dim_order); - } - if (!dim_order.empty()) { + py.Input("index_map", index_map->ToPythonString()); + if (!consumer_blocks.empty()) { py.Input("consumer_blocks", consumer_blocks); } py.SingleOutput(outputs); @@ -2385,8 +2296,8 @@ TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheInplaceTraits); TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits); -TVM_REGISTER_INST_KIND_TRAITS(ReverseCacheReadTraits); -TVM_REGISTER_INST_KIND_TRAITS(ReverseCacheWriteTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReindexCacheReadTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReindexCacheWriteTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 6b47397c7fc3..354500b0fa76 100644 --- a/src/tir/schedule/schedule.cc +++ 
b/src/tir/schedule/schedule.cc @@ -179,10 +179,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseCacheRead") - .set_body_method(&ScheduleNode::ReverseCacheRead); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseCacheWrite") - .set_body_method(&ScheduleNode::ReverseCacheWrite); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheRead") + .set_body_method(&ScheduleNode::ReindexCacheRead); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheWrite") + .set_body_method(&ScheduleNode::ReindexCacheWrite); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace") .set_body_method(&ScheduleNode::CacheInplace); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex") diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index df98cbebdbb9..84b2e00ab5d1 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -309,34 +309,34 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } -BlockRV TracedScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, Array dim_order, +BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, const IndexMap& index_map, Array consumer_blocks) { - BlockRV result = ConcreteScheduleNode::ReverseCacheRead( - block_rv, read_buffer_index, storage_scope, dim_order, consumer_blocks); + BlockRV result = ConcreteScheduleNode::ReindexCacheRead( + block_rv, read_buffer_index, storage_scope, index_map, consumer_blocks); - static const InstructionKind& kind = InstructionKind::Get("ReverseCacheRead"); + static const InstructionKind& kind = InstructionKind::Get("ReindexCacheRead"); trace_->Append( /*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(read_buffer_index), storage_scope, dim_order, consumer_blocks}, + /*attrs=*/{Integer(read_buffer_index), storage_scope, index_map, consumer_blocks}, /*outputs=*/{result})); return result; } -BlockRV TracedScheduleNode::ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, Array dim_order, +BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, const IndexMap& index_map, Array consumer_blocks) { - BlockRV result = ConcreteScheduleNode::ReverseCacheWrite( - block_rv, write_buffer_index, storage_scope, dim_order, consumer_blocks); + BlockRV result = ConcreteScheduleNode::ReindexCacheWrite( + block_rv, write_buffer_index, storage_scope, index_map, consumer_blocks); - static const InstructionKind& kind = InstructionKind::Get("ReverseCacheWrite"); + static const InstructionKind& kind = InstructionKind::Get("ReindexCacheWrite"); trace_->Append( /*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(write_buffer_index), storage_scope, dim_order, consumer_blocks}, + /*attrs=*/{Integer(write_buffer_index), storage_scope, index_map, consumer_blocks}, /*outputs=*/{result})); return result; } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index d524e25fcb99..6afb60ad7035 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -76,11 +76,11 @@ class TracedScheduleNode : public 
ConcreteScheduleNode { const Array consumer_blocks = {}) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) final; - BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, Array dim_order = {}, + BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, const IndexMap& index_map, Array consumer_blocks = {}) final; - BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, Array dim_order = {}, + BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, const IndexMap& index_map, Array consumer_blocks = {}) final; Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) final; From 7d5f1312ec2aedb152b1a3fb1489a11df12fb015 Mon Sep 17 00:00:00 2001 From: Zihao Date: Wed, 1 Mar 2023 01:43:13 -0800 Subject: [PATCH 04/13] pass test --- python/tvm/contrib/nvcc.py | 5 +- python/tvm/tir/schedule/schedule.py | 46 +++++++--- .../schedule/primitive/cache_read_write.cc | 39 ++++----- src/tir/schedule/traced_schedule.cc | 11 +-- .../test_tir_schedule_cache_read_write.py | 85 +++++++++++++++++++ 5 files changed, 148 insertions(+), 38 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 5a104be9966d..8835e2b25f4c 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -99,8 +99,9 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env. # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. - # if cxx_compiler_path != "": - # cmd += ["-ccbin", cxx_compiler_path] + cxx_compiler_path = os.environ.get("CUDAHOSTCXX", "") + if cxx_compiler_path != "": + cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 44c5ca4fb4ca..1d4554f65785 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1115,7 +1115,7 @@ def cache_write( block: Union[BlockRV, str], write_buffer_index: Union[int, str, Buffer], storage_scope: str, - consumer_blocks=None, + consumer_blocks: Optional[List[Union[BlockRV, str]]] = None, ) -> BlockRV: """Create a block that reads a buffer region into a write cache. It requires: @@ -1205,13 +1205,19 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: @type_checked def reindex_cache_read( - self, block: BlockRV, read_buffer_index: int, storage_scope: str, + self, + block: Union[BlockRV, str], + read_buffer_index: int, + storage_scope: str, index_map: Union[IndexMap, Callable], - consumer_blocks=None, + consumer_blocks: Optional[List[Union[BlockRV, str]]] = None, ) -> BlockRV: - """Create a block that reads a buffer region into a read cache, with user customized - indices specified by index map. - The read region of the buffer to read in the block must be a single point. + """Create a block that reads a buffer region into a read cache using customized + indices specified by index map. The read region of the buffer must be a single point. + + The cache stage block follows the original order of loops and block itervars in the block. 
+ If a block itervar does not appear in the buffer access region, + it and its corresponding loop variables will be omitted. Parameters ---------- @@ -1277,6 +1283,11 @@ def after_reindex_cache_read(a: T.handle, b: T.handle) -> None: """ if consumer_blocks is None: consumer_blocks = [] + + # Convert any string block names into Block RVs. + consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks] + block = self._normalize_block_arg(block) + if callable(index_map): index_map = IndexMap.from_func(index_map) return _ffi_api.ScheduleReindexCacheRead( # type: ignore # pylint: disable=no-member @@ -1285,17 +1296,23 @@ def after_reindex_cache_read(a: T.handle, b: T.handle) -> None: @type_checked def reindex_cache_write( - self, block: BlockRV, write_buffer_index: int, storage_scope: str, + self, + block: Union[BlockRV, str], + write_buffer_index: int, + storage_scope: str, index_map: Union[Callable, IndexMap], - consumer_blocks=None, + consumer_blocks: Optional[List[Union[BlockRV, str]]] = None, ) -> BlockRV: - """Create a block that reads a buffer region into a write cache, with user customized - indices specified by index map. - The write region of the buffer to write in the block must be a single point. + r"""Create a block that reads a buffer region into a write cache using customized + indices specified by index map. The write region of the buffer must be a single point. + + The cache stage block follows the original order of loops and block itervars in the block. + If a block itervar does not appear in the buffer access region, + it and its corresponding loop variables will be omitted. Parameters ---------- - block : BlockRV + block : Union[BlockRV, str] The consumer block of the target buffer. write_buffer_index: int The index of the buffer in block's write region. @@ -1357,6 +1374,11 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: """ if consumer_blocks is None: consumer_blocks = [] + + # Convert any string block names into Block RVs. 
+ consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks] + block = self._normalize_block_arg(block) + if callable(index_map): index_map = IndexMap.from_func(index_map) return _ffi_api.ScheduleReindexCacheWrite( # type: ignore # pylint: disable=no-member diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 0c42b7f09844..c6db28f9099f 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -158,7 +158,7 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageI Map var_map; for (size_t i = 0; i < info->loop_vars.size(); ++i) { Var original_var = info->loop_vars[i]; - Var loop_var("ax" + std::to_string(i), original_var.dtype()); + Var loop_var(original_var->name_hint, original_var.dtype()); var_map.Set(original_var, loop_var); loop_vars.push_back(loop_var); } @@ -167,7 +167,7 @@ Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageI PrimExpr original_iter_value = info->block_iter_values[i]; IterVar block_var = IterVar( /*dom=*/original_block_var->dom, - /*var=*/Var("v" + std::to_string(i), original_block_var->var.dtype()), + /*var=*/Var(original_block_var->var->name_hint, original_block_var->var.dtype()), /*IterVarType=*/kDataPar); var_map.Set(original_block_var->var, block_var->var); block_vars.push_back(block_var); @@ -506,7 +506,8 @@ class CacheLocDetector : public StmtVisitor { * \brief Detect the insertion position of the cache stage, and write the position into the * CacheStageInfo * \param self The state of the schedule - * \param block_sref The sref of the unique writer block of the buffer being applied cache_read or cache_write + * \param block_sref The sref of the unique writer block of the buffer being applied cache_read or + * cache_write * \param scope_sref The sref of the scope block of the cached block * \param info The cache stage info. 
*/ @@ -2227,20 +2228,20 @@ struct ReindexCacheReadTraits : public UnpackedInstTraits consumer_blocks) { + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, IndexMap index_map, + Array consumer_blocks, Integer read_buffer_index, + String storage_scope) { return sch->ReindexCacheRead(block, read_buffer_index->value, storage_scope, index_map, consumer_blocks); } - static String UnpackedAsPython(Array outputs, String block, Integer read_buffer_index, - String storage_scope, IndexMap index_map, - Array consumer_blocks) { + static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, + Array consumer_blocks, Integer read_buffer_index, + String storage_scope) { PythonAPICall py("reindex_cache_read"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); @@ -2262,20 +2263,20 @@ struct ReindexCacheWriteTraits : public UnpackedInstTraits consumer_blocks) { + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, IndexMap index_map, + Array consumer_blocks, Integer write_buffer_index, + String storage_scope) { return sch->ReindexCacheWrite(block, write_buffer_index->value, storage_scope, index_map, consumer_blocks); } - static String UnpackedAsPython(Array outputs, String block, Integer write_buffer_index, - String storage_scope, IndexMap index_map, - Array consumer_blocks) { + static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, + Array consumer_blocks, Integer write_buffer_index, + String storage_scope) { PythonAPICall py("reindex_cache_write"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index->value); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 84b2e00ab5d1..21210011f963 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -319,14 +319,15 @@ BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_b trace_->Append( /*inst=*/Instruction( /*kind=*/kind, - /*inputs=*/{block_rv}, - /*attrs=*/{Integer(read_buffer_index), storage_scope, index_map, consumer_blocks}, + /*inputs=*/{block_rv, index_map, consumer_blocks}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, /*outputs=*/{result})); return result; } BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map, + const String& storage_scope, + const IndexMap& index_map, Array consumer_blocks) { BlockRV result = ConcreteScheduleNode::ReindexCacheWrite( block_rv, write_buffer_index, storage_scope, index_map, consumer_blocks); @@ -335,8 +336,8 @@ BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write trace_->Append( /*inst=*/Instruction( /*kind=*/kind, - /*inputs=*/{block_rv}, - /*attrs=*/{Integer(write_buffer_index), storage_scope, index_map, consumer_blocks}, + /*inputs=*/{block_rv, index_map, consumer_blocks}, + /*attrs=*/{Integer(write_buffer_index), storage_scope}, /*outputs=*/{result})); return result; } diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index be91505f3d15..4f507628ebc6 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -59,6 +59,58 @@ def elementwise_shape_int64(a: T.handle, c: T.handle) -> None: C[vi, vj] = B[vi, vj] + 1.0 +@T.prim_func +def elementwise_reindex_cache_read( + A: T.Buffer((128, 
128), "float32"), C: T.Buffer((128, 128), "float32") +): + B = T.alloc_buffer((128, 128)) + B_shared = T.alloc_buffer((128, 64, 2), scope="shared") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("B_shared"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(B_shared[vj, vi // 2, vi % 2]) + B_shared[vj, vi // 2, vi % 2] = B[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B_shared[vj, vi // 2, vi % 2]) + T.writes(C[vi, vj]) + C[vi, vj] = B_shared[vj, vi // 2, vi % 2] + T.float32(1) + + +@T.prim_func +def elementwise_reindex_cache_write( + A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32") +): + B = T.alloc_buffer((128, 128)) + B_shared = T.alloc_buffer((128, 128), scope="shared") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B_shared[vj, vi]) + B_shared[vj, vi] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("B_shared"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B_shared[vj, vi]) + T.writes(B[vi, vj]) + B[vi, vj] = B_shared[vj, vi] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + @T.prim_func def func_nested_seq(b: T.handle, c: T.handle) -> None: A = T.alloc_buffer((128, 128)) @@ -1336,5 +1388,38 @@ def test_cache_write_allocate_const(): verify_trace_roundtrip(sch=sch, mod=cache_write_allocate_const) +def test_reindex_cache_read(): + sch = tir.Schedule(elementwise, debug_mask="all") + sch.reindex_cache_read("C", 0, "shared", lambda i, j: (j, i // 2, i % 2)) + tvm.ir.assert_structural_equal(elementwise_reindex_cache_read, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + sch = tir.Schedule(elementwise, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.reindex_cache_read( + "C", + 0, + "shared", + lambda i, j: j * 2, + ) + + +def test_reindex_cache_write(): + sch = tir.Schedule(elementwise, debug_mask="all") + sch.reindex_cache_write("B", 0, "shared", lambda i, j: (j, i)) + print(sch.mod["main"].show()) + tvm.ir.assert_structural_equal(elementwise_reindex_cache_write, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + sch = tir.Schedule(elementwise, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.reindex_cache_write( + "B", + 0, + "shared", + lambda i, j: i, + ) + + if __name__ == "__main__": tvm.testing.main() From 54bc4277eee3fa544a3237daed74240a880c5a0e Mon Sep 17 00:00:00 2001 From: Zihao Date: Wed, 1 Mar 2023 02:22:09 -0800 Subject: [PATCH 05/13] add some extra docstring --- python/tvm/tir/schedule/schedule.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 1d4554f65785..da9507ea2ce6 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1216,8 +1216,9 @@ def reindex_cache_read( indices specified by index map. The read region of the buffer must be a single point. The cache stage block follows the original order of loops and block itervars in the block. 
- If a block itervar does not appear in the buffer access region, - it and its corresponding loop variables will be omitted. + If a block itervar does not appear in the buffer access region, it and its corresponding loop + variables will be omitted. User can then use `transform_block_layout` primitive to reorder + the block itervars and surrounding loops of the cache read/write block. Parameters ---------- @@ -1307,8 +1308,9 @@ def reindex_cache_write( indices specified by index map. The write region of the buffer must be a single point. The cache stage block follows the original order of loops and block itervars in the block. - If a block itervar does not appear in the buffer access region, - it and its corresponding loop variables will be omitted. + If a block itervar does not appear in the buffer access region, it and its corresponding loop + variables will be omitted. User can then use `transform_block_layout` primitive to reorder + the block itervars and surrounding loops of the cache read/write block. Parameters ---------- From 4eae481cc7a8092e6b8efbda754fab3e075948db Mon Sep 17 00:00:00 2001 From: Zihao Date: Wed, 1 Mar 2023 06:05:20 -0800 Subject: [PATCH 06/13] fix lint --- CMakeLists.txt | 4 ++-- python/tvm/contrib/nvcc.py | 5 ++--- python/tvm/tir/schedule/schedule.py | 12 ++++++------ 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 818e8b50addb..c2b3d0732b16 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ endif() # You can create a config.cmake at build folder # and add set(OPTION VALUE) to override these build options. # Alernatively, use cmake -DOPTION=VALUE through command-line. -tvm_option(USE_CUDA "Build with CUDA" OFF) +tvm_option(USE_CUDA "Build with CUDA" ON) tvm_option(USE_OPENCL "Build with OpenCL" OFF) tvm_option(USE_OPENCL_ENABLE_HOST_PTR "Enable OpenCL memory object access to host" OFF) tvm_option(USE_OPENCL_GTEST "Path to OpenCL specific gtest version for runtime cpp tests." /path/to/opencl/gtest) @@ -61,7 +61,7 @@ tvm_option(USE_RTTI "Build with RTTI" ON) tvm_option(USE_MSVC_MT "Build with MT" OFF) tvm_option(USE_MICRO "Build with Micro TVM support" OFF) tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF) -tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF) +tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." ON) tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF) tvm_option(USE_PT_TVMDSOOP "Build with PyTorch TVMDSOOp" OFF) tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 8835e2b25f4c..5a104be9966d 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -99,9 +99,8 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env. # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. 
- cxx_compiler_path = os.environ.get("CUDAHOSTCXX", "") - if cxx_compiler_path != "": - cmd += ["-ccbin", cxx_compiler_path] + # if cxx_compiler_path != "": + # cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index da9507ea2ce6..44782e650c50 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1216,9 +1216,9 @@ def reindex_cache_read( indices specified by index map. The read region of the buffer must be a single point. The cache stage block follows the original order of loops and block itervars in the block. - If a block itervar does not appear in the buffer access region, it and its corresponding loop - variables will be omitted. User can then use `transform_block_layout` primitive to reorder - the block itervars and surrounding loops of the cache read/write block. + If a block itervar does not appear in the buffer access region, it and its corresponding + loop variables will be omitted. User can then use `transform_block_layout` primitive to + reorder the block itervars and surrounding loops of the cache read/write block. Parameters ---------- @@ -1308,9 +1308,9 @@ def reindex_cache_write( indices specified by index map. The write region of the buffer must be a single point. The cache stage block follows the original order of loops and block itervars in the block. - If a block itervar does not appear in the buffer access region, it and its corresponding loop - variables will be omitted. User can then use `transform_block_layout` primitive to reorder - the block itervars and surrounding loops of the cache read/write block. + If a block itervar does not appear in the buffer access region, it and its corresponding + loop variables will be omitted. User can then use `transform_block_layout` primitive to + reorder the block itervars and surrounding loops of the cache read/write block. Parameters ---------- From 8698b2108191a55acae8b87b676ba70e1e3ace63 Mon Sep 17 00:00:00 2001 From: Zihao Date: Wed, 1 Mar 2023 06:28:19 -0800 Subject: [PATCH 07/13] fix cpplint --- CMakeLists.txt | 4 ++-- src/tir/schedule/primitive/cache_read_write.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c2b3d0732b16..818e8b50addb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ endif() # You can create a config.cmake at build folder # and add set(OPTION VALUE) to override these build options. # Alernatively, use cmake -DOPTION=VALUE through command-line. -tvm_option(USE_CUDA "Build with CUDA" ON) +tvm_option(USE_CUDA "Build with CUDA" OFF) tvm_option(USE_OPENCL "Build with OpenCL" OFF) tvm_option(USE_OPENCL_ENABLE_HOST_PTR "Enable OpenCL memory object access to host" OFF) tvm_option(USE_OPENCL_GTEST "Path to OpenCL specific gtest version for runtime cpp tests." /path/to/opencl/gtest) @@ -61,7 +61,7 @@ tvm_option(USE_RTTI "Build with RTTI" ON) tvm_option(USE_MSVC_MT "Build with MT" OFF) tvm_option(USE_MICRO "Build with Micro TVM support" OFF) tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF) -tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." ON) +tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." 
OFF) tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF) tvm_option(USE_PT_TVMDSOOP "Build with PyTorch TVMDSOOp" OFF) tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF) diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index c6db28f9099f..b5f8462bbd05 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -1603,7 +1603,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu return result_block_sref; } -/*! \brief A visitor that collects variables appeared in expressions, stored in touched filed.*/ +/*! \brief A visitor that collects variables appeared in expressions, stored in `touched` field.*/ class VarCollector : public ExprVisitor { public: VarCollector() {} From 7bca3c2ca9fb144244f1a6b5968f5911d47a956b Mon Sep 17 00:00:00 2001 From: Zihao Date: Wed, 1 Mar 2023 07:01:28 -0800 Subject: [PATCH 08/13] header lint --- include/tvm/tir/schedule/schedule.h | 12 ++++++------ src/tir/schedule/primitive.h | 18 ++++++++---------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index d3329a6a0339..f8e0efd11d71 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -413,9 +413,9 @@ class ScheduleNode : public runtime::Object { * \param block_rv The consumer block of the target buffer. * \param read_buffer_index The index of the buffer in block's read region. * \param storage_scope The target storage scope. - * \param index_map User defined indices to access allocated cache buffer, maps from block iter vars. - * \param consumer_blocks An optional list of consumers to read from cache directly. - * \return The cache stage block. + * \param index_map User defined indices to access allocated cache buffer, maps from block iter + * vars. \param consumer_blocks An optional list of consumers to read from cache directly. \return + * The cache stage block. */ virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope, const IndexMap& index_map, @@ -428,9 +428,9 @@ class ScheduleNode : public runtime::Object { * \param block_rv The producer of the buffer * \param write_buffer_index The index of the buffer in block's write region * \param storage_scope The target storage scope - * \param index_map User defined indices to access allocated cache buffer, maps from block iter vars. - * \param consumer_blocks An optional list of consumers to read from cache directly. - * \return The cache stage block. + * \param index_map User defined indices to access allocated cache buffer, maps from block iter + * vars. \param consumer_blocks An optional list of consumers to read from cache directly. \return + * The cache stage block. */ virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, const IndexMap& index_map, diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 1ad75c313ce0..77654fc6660d 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -278,14 +278,13 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int * \param block_sref The consumer block of the target buffer. * \param read_buffer_index The index of the buffer in block's read region. * \param storage_scope The target storage scope. 
- * \param index_map User defined indices to access allocated cache buffer, maps from block iter vars. - * \param consumer_blocks Array of blocks that consume the cache. - * \return The cache stage block. + * \param index_map User defined indices to access allocated cache buffer, maps from block iter + * vars. \param consumer_blocks Array of blocks that consume the cache. \return The cache stage + * block. */ TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, const String& storage_scope, - const IndexMap& index_map, - Array consumer_blocks = {}); + const IndexMap& index_map, Array consumer_blocks = {}); /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block that writes the target buffer. @@ -295,14 +294,13 @@ TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref * \param block_sref The producer of the buffer * \param write_buffer_index The index of the buffer in block's write region * \param storage_scope The target storage scope - * \param index_map User defined indices to access allocated cache buffer, maps from block iter vars. - * \param consumer_blocks Array of blocks that consume the cache. - * \return The cache stage block. + * \param index_map User defined indices to access allocated cache buffer, maps from block iter + * vars. \param consumer_blocks Array of blocks that consume the cache. \return The cache stage + * block. */ TVM_DLL StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope, - const IndexMap& index_map, - Array consumer_blocks = {}); + const IndexMap& index_map, Array consumer_blocks = {}); /*! *! From 733befeed07c90eadce5d770fcb838e0a026260e Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 3 Mar 2023 00:42:08 -0800 Subject: [PATCH 09/13] only one consumer --- include/tvm/tir/schedule/schedule.h | 14 ++-- python/tvm/tir/schedule/schedule.py | 36 +++++---- src/tir/schedule/concrete_schedule.cc | 20 +---- src/tir/schedule/concrete_schedule.h | 6 +- src/tir/schedule/primitive.h | 12 +-- .../schedule/primitive/cache_read_write.cc | 74 ++++++------------- src/tir/schedule/traced_schedule.cc | 14 ++-- src/tir/schedule/traced_schedule.h | 6 +- .../test_tir_schedule_cache_read_write.py | 57 +++++++++++++- 9 files changed, 128 insertions(+), 111 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index f8e0efd11d71..76f2ccf67739 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -414,12 +414,11 @@ class ScheduleNode : public runtime::Object { * \param read_buffer_index The index of the buffer in block's read region. * \param storage_scope The target storage scope. * \param index_map User defined indices to access allocated cache buffer, maps from block iter - * vars. \param consumer_blocks An optional list of consumers to read from cache directly. \return - * The cache stage block. + * vars. + * \return The cache stage block. */ virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map, - Array consumer_blocks) = 0; + const String& storage_scope, const IndexMap& index_map) = 0; /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block who writes the target buffer. 
@@ -429,12 +428,11 @@ class ScheduleNode : public runtime::Object { * \param write_buffer_index The index of the buffer in block's write region * \param storage_scope The target storage scope * \param index_map User defined indices to access allocated cache buffer, maps from block iter - * vars. \param consumer_blocks An optional list of consumers to read from cache directly. \return - * The cache stage block. + * vars. + * \return The cache stage block. */ virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map, - Array consumer_blocks) = 0; + const String& storage_scope, const IndexMap& index_map) = 0; /*! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. * It requires the the target block both read & write the target buffer. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 44782e650c50..7431afe10bc9 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1210,7 +1210,6 @@ def reindex_cache_read( read_buffer_index: int, storage_scope: str, index_map: Union[IndexMap, Callable], - consumer_blocks: Optional[List[Union[BlockRV, str]]] = None, ) -> BlockRV: """Create a block that reads a buffer region into a read cache using customized indices specified by index map. The read region of the buffer must be a single point. @@ -1220,6 +1219,9 @@ def reindex_cache_read( loop variables will be omitted. User can then use `transform_block_layout` primitive to reorder the block itervars and surrounding loops of the cache read/write block. + Unlike `cache_read`, `reindex_cache_read` only supports single consumer, please use + `cache_read` when there are multiple consumers. + Parameters ---------- block : BlockRV @@ -1230,9 +1232,6 @@ def reindex_cache_read( The target storage scope. index_map: Union[IndexMap, Callable] User defined indices to access allocated cache buffer, maps from block iter vars. - consumer_blocks: Optional[List[Union[BlockRV, str]]] - An optional list of consumers that should read directly from the cache. - If not specified, all consumers will read from the original buffer. Returns ------- @@ -1281,18 +1280,21 @@ def after_reindex_cache_read(a: T.handle, b: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A_local[vj, vi] * 2.0 + See Also + -------- + reindex_cache_write + transform_block_layout + transform_layout + cache_read + reindex """ - if consumer_blocks is None: - consumer_blocks = [] - # Convert any string block names into Block RVs. - consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks] block = self._normalize_block_arg(block) if callable(index_map): index_map = IndexMap.from_func(index_map) return _ffi_api.ScheduleReindexCacheRead( # type: ignore # pylint: disable=no-member - self, block, read_buffer_index, storage_scope, index_map, consumer_blocks + self, block, read_buffer_index, storage_scope, index_map ) @type_checked @@ -1312,6 +1314,9 @@ def reindex_cache_write( loop variables will be omitted. User can then use `transform_block_layout` primitive to reorder the block itervars and surrounding loops of the cache read/write block. + Unlike `cache_write`, `reindex_cache_write` only supports single consumer, please use + `cache_write` when there are multiple consumers. 
+ Parameters ---------- block : Union[BlockRV, str] @@ -1373,18 +1378,21 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = B_local[vi % 2, vi // 2, vj] + See Also + -------- + reindex_cache_read + transform_block_layout + transform_layout + cache_write + reindex """ - if consumer_blocks is None: - consumer_blocks = [] - # Convert any string block names into Block RVs. - consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks] block = self._normalize_block_arg(block) if callable(index_map): index_map = IndexMap.from_func(index_map) return _ffi_api.ScheduleReindexCacheWrite( # type: ignore # pylint: disable=no-member - self, block, write_buffer_index, storage_scope, index_map, consumer_blocks + self, block, write_buffer_index, storage_scope, index_map ) @type_checked diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 89c454178f6d..3b5dce41da1a 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -570,17 +570,11 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope, - const IndexMap& index_map, - const Array consumer_blocks) { + const IndexMap& index_map) { StmtSRef result{nullptr}; - // Create a new array of SRefs from the consumer block list. - Array consumer_block_refs = {}; - for (BlockRV block : consumer_blocks) { - consumer_block_refs.push_back(this->GetSRef(block)); - } TVM_TIR_SCHEDULE_BEGIN(); result = tir::ReindexCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope, - index_map, consumer_block_refs); + index_map); TVM_TIR_SCHEDULE_END("reverse-cache-read", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); @@ -588,17 +582,11 @@ BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read BlockRV ConcreteScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, - const IndexMap& index_map, - const Array consumer_blocks) { + const IndexMap& index_map) { StmtSRef result{nullptr}; - // Create a new array of SRefs from the consumer block list. 
- Array consumer_block_refs = {}; - for (BlockRV block : consumer_blocks) { - consumer_block_refs.push_back(this->GetSRef(block)); - } TVM_TIR_SCHEDULE_BEGIN(); result = tir::ReindexCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, - storage_scope, index_map, consumer_block_refs); + storage_scope, index_map); TVM_TIR_SCHEDULE_END("reverse-cache-write", this->error_render_level_); this->state_->DebugVerify(); return CreateRV(result); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index bb4c712f1e28..e72fac956b6e 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -117,11 +117,9 @@ class ConcreteScheduleNode : public ScheduleNode { BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) override; BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map, - Array consumer_blocks = {}) override; + const String& storage_scope, const IndexMap& index_map) override; BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map, - Array consumer_blocks = {}) override; + const String& storage_scope, const IndexMap& index_map) override; Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) override; Array CacheIndex(const BlockRV& block_rv, const String& storage_scope, diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 77654fc6660d..5b44ccc7968c 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -279,12 +279,12 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int * \param read_buffer_index The index of the buffer in block's read region. * \param storage_scope The target storage scope. * \param index_map User defined indices to access allocated cache buffer, maps from block iter - * vars. \param consumer_blocks Array of blocks that consume the cache. \return The cache stage - * block. + * vars. + * \return The cache stage block. */ TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, const String& storage_scope, - const IndexMap& index_map, Array consumer_blocks = {}); + const IndexMap& index_map); /*! * \brief Create a block that writes a buffer region into a write cache. It requires: * 1) There is only one block that writes the target buffer. @@ -295,12 +295,12 @@ TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref * \param write_buffer_index The index of the buffer in block's write region * \param storage_scope The target storage scope * \param index_map User defined indices to access allocated cache buffer, maps from block iter - * vars. \param consumer_blocks Array of blocks that consume the cache. \return The cache stage - * block. + * vars. + * \return The cache stage block. */ TVM_DLL StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope, - const IndexMap& index_map, Array consumer_blocks = {}); + const IndexMap& index_map); /*! *! 
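For orientation, a minimal usage sketch of the two primitives declared above, written against the Python API as it stands after this patch (an index_map in place of dim_order, and no consumer_blocks argument). The workload below mirrors the elementwise prim_func used by the unit tests and is illustrative only; the (j, i // 2, i % 2) layout is an arbitrary example choice, not something this patch mandates.

import tvm
from tvm import tir
from tvm.script import tir as T

@T.prim_func
def elementwise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
    # Producer block "B" fills an intermediate buffer; consumer block "C" reads it back.
    B = T.alloc_buffer((128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + T.float32(1)

sch = tir.Schedule(elementwise)
# Stage the read of B (read buffer index 0 of block "C") through shared memory.
# The lambda maps the consumer's block itervars (vi, vj) to the cache layout, so the
# cache buffer is allocated as (128, 64, 2) and indexed as (vj, vi // 2, vi % 2).
# The staged access must stay a single point, which is what CheckSinglePoint enforces.
sch.reindex_cache_read("C", 0, "shared", lambda i, j: (j, i // 2, i % 2))
print(sch.mod["main"].script())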
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index b5f8462bbd05..2185f99bb723 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -120,11 +120,11 @@ class NotSinglePointAccess : public ScheduleError { } String DetailRenderTemplate() const final { - std::stringstream s; - s << "The buffer region " << cache_region_ - << " accessed inside block {0} is not a single point, which violates" - << " the prerequisite of " << primitive_name_ << " primitive."; - return String(s.str()); + std::ostringstream os; + os << "The buffer region " << cache_region_ + << " accessed inside block {0} is not a single point, which violates" + << " the prerequisite of " << primitive_name_ << " primitive."; + return String(os.str()); } IRModule mod() const final { return mod_; } @@ -985,12 +985,13 @@ class CacheWriteRewriter : public StmtExprMutator { const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); Block consumer_block = GetRef(consumer_node); if (old_stmt.same_as(consumer_block)) { - Array reads = - ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer); - Array match_buffers = - ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer); - if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { + Array writes = update_access_regions(block->writes); + Array reads = update_access_regions(block->reads); + Array match_buffers = update_match_buffers(block->match_buffers); + if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || + !match_buffers.same_as(block->match_buffers)) { auto n = CopyOnWrite(block); + n->writes = std::move(writes); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); n->body = VisitStmt(block->body); @@ -1764,8 +1765,7 @@ void CheckSinglePoint(ScheduleState self, const Block& block, const BufferRegion } StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map, - const Array consumer_blocks) { + const String& storage_scope, const IndexMap& index_map) { /*! * Check: * - The index is in the array of block reading region @@ -1789,14 +1789,7 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re // Step 2. Create CacheStageInfo ReindexCacheStageInfo info; info.read_buffer = read_buffer; - - // info.consumer_blocks indicates which buffers should consume the cache. - for (auto consumer : consumer_blocks) { - info.consumer_blocks.insert(consumer); - for (auto child : tir::GetChildBlocks(self, consumer)) { - info.consumer_blocks.insert(child); - } - } + info.consumer_blocks.insert(block_sref); // Step 3. Update cache stage info. Optional maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer); @@ -1842,8 +1835,7 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re } StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map, - const Array consumer_blocks) { + const String& storage_scope, const IndexMap& index_map) { /*! * Check: * - The index is in the array of block reading region @@ -1868,14 +1860,8 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w // Step 2. 
Creating CacheStageInfo ReindexCacheStageInfo info; info.write_buffer = write_buffer; - - // info.consumer_blocks indicates which buffers should consume the cache. - for (auto consumer : consumer_blocks) { - info.consumer_blocks.insert(consumer); - for (auto child : tir::GetChildBlocks(self, consumer)) { - info.consumer_blocks.insert(child); - } - } + LOG(INFO) << block->name_hint; + info.consumer_blocks.insert(block_sref); // Step 3. Check the only writer block. ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); @@ -2228,28 +2214,22 @@ struct ReindexCacheReadTraits : public UnpackedInstTraits consumer_blocks, Integer read_buffer_index, - String storage_scope) { - return sch->ReindexCacheRead(block, read_buffer_index->value, storage_scope, index_map, - consumer_blocks); + Integer read_buffer_index, String storage_scope) { + return sch->ReindexCacheRead(block, read_buffer_index->value, storage_scope, index_map); } static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, - Array consumer_blocks, Integer read_buffer_index, - String storage_scope) { + Integer read_buffer_index, String storage_scope) { PythonAPICall py("reindex_cache_read"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); py.Input("storage_scope", storage_scope); py.Input("index_map", index_map->ToPythonString()); - if (!consumer_blocks.empty()) { - py.Input("consumer_blocks", consumer_blocks); - } py.SingleOutput(outputs); return py.Str(); } @@ -2263,28 +2243,22 @@ struct ReindexCacheWriteTraits : public UnpackedInstTraits consumer_blocks, Integer write_buffer_index, - String storage_scope) { - return sch->ReindexCacheWrite(block, write_buffer_index->value, storage_scope, index_map, - consumer_blocks); + Integer write_buffer_index, String storage_scope) { + return sch->ReindexCacheWrite(block, write_buffer_index->value, storage_scope, index_map); } static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, - Array consumer_blocks, Integer write_buffer_index, - String storage_scope) { + Integer write_buffer_index, String storage_scope) { PythonAPICall py("reindex_cache_write"); py.Input("block", block); py.Input("write_buffer_index", write_buffer_index->value); py.Input("storage_scope", storage_scope); py.Input("index_map", index_map->ToPythonString()); - if (!consumer_blocks.empty()) { - py.Input("consumer_blocks", consumer_blocks); - } py.SingleOutput(outputs); return py.Str(); } diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 21210011f963..e454b82c5315 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -310,16 +310,15 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer } BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map, - Array consumer_blocks) { + const String& storage_scope, const IndexMap& index_map) { BlockRV result = ConcreteScheduleNode::ReindexCacheRead( - block_rv, read_buffer_index, storage_scope, index_map, consumer_blocks); + block_rv, read_buffer_index, storage_scope, index_map); static const InstructionKind& kind = InstructionKind::Get("ReindexCacheRead"); trace_->Append( /*inst=*/Instruction( /*kind=*/kind, - /*inputs=*/{block_rv, index_map, consumer_blocks}, + /*inputs=*/{block_rv, index_map}, /*attrs=*/{Integer(read_buffer_index), storage_scope}, /*outputs=*/{result})); return 
result; @@ -327,16 +326,15 @@ BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_b BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, - const IndexMap& index_map, - Array consumer_blocks) { + const IndexMap& index_map) { BlockRV result = ConcreteScheduleNode::ReindexCacheWrite( - block_rv, write_buffer_index, storage_scope, index_map, consumer_blocks); + block_rv, write_buffer_index, storage_scope, index_map); static const InstructionKind& kind = InstructionKind::Get("ReindexCacheWrite"); trace_->Append( /*inst=*/Instruction( /*kind=*/kind, - /*inputs=*/{block_rv, index_map, consumer_blocks}, + /*inputs=*/{block_rv, index_map}, /*attrs=*/{Integer(write_buffer_index), storage_scope}, /*outputs=*/{result})); return result; diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 6afb60ad7035..45b56f6b2689 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -77,11 +77,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) final; BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope, const IndexMap& index_map, - Array consumer_blocks = {}) final; + const String& storage_scope, const IndexMap& index_map) final; BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope, const IndexMap& index_map, - Array consumer_blocks = {}) final; + const String& storage_scope, const IndexMap& index_map) final; Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) final; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index 4f507628ebc6..cf75768ec0e3 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -268,6 +268,39 @@ def func_multi_consumer() -> None: C[vi] = A[vi] +@T.prim_func +def reindex_cache_read_multi_consumer() -> None: + A = T.alloc_buffer((128,)) + B = T.alloc_buffer((128,)) + C = T.alloc_buffer((128,)) + A_shared = T.alloc_buffer((4, 32), scope="shared") + for i in range(8): + for j in range(16): + with T.block("A"): + vi = T.axis.spatial(128, i * 16 + j) + T.reads() + T.writes(A[vi]) + A[vi] = T.float32(1) + for j in range(16): + with T.block("A_shared"): + vi = T.axis.spatial(128, i * 16 + j) + T.reads(A[vi]) + T.writes(A_shared[vi // 32, vi % 32]) + A_shared[vi // 32, vi % 32] = A[vi] + for j in range(16): + with T.block("B"): + vi = T.axis.spatial(128, i * 16 + j) + T.reads(A_shared[vi // 32, vi % 32]) + T.writes(B[vi]) + B[vi] = A_shared[vi // 32, vi % 32] + T.float32(1) + for i in range(128): + with T.block("C"): + vi = T.axis.spatial(128, i) + T.reads(A[vi]) + T.writes(C[vi]) + C[vi] = A[vi] + + @T.prim_func def func_multi_producer() -> None: A = T.alloc_buffer((128)) @@ -1394,6 +1427,15 @@ def test_reindex_cache_read(): tvm.ir.assert_structural_equal(elementwise_reindex_cache_read, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=elementwise) + +def test_reindex_cache_read_multi_consumer(): + sch = tir.Schedule(func_multi_consumer) + sch.reindex_cache_read("B", 0, "shared", lambda i: (i // 32, i % 32)) + 
+    tvm.ir.assert_structural_equal(reindex_cache_read_multi_consumer, sch.mod["main"])
+    # NOTE(zihao): we do not verify trace roundtrip because of int set analysis issues.
+
+
+def test_reindex_cache_read_fail_not_match():
     sch = tir.Schedule(elementwise, debug_mask="all")
     with pytest.raises(tvm.tir.ScheduleError):
         sch.reindex_cache_read(
@@ -1404,13 +1446,20 @@ def test_reindex_cache_read():
         )
 
 
+def test_reindex_cache_read_fail_not_single_point():
+    sch = tir.Schedule(access_under_scope, debug_mask="all")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reindex_cache_read("scope", 0, "shared", lambda i, j: (i, j))
+
+
 def test_reindex_cache_write():
     sch = tir.Schedule(elementwise, debug_mask="all")
     sch.reindex_cache_write("B", 0, "shared", lambda i, j: (j, i))
-    print(sch.mod["main"].show())
     tvm.ir.assert_structural_equal(elementwise_reindex_cache_write, sch.mod["main"])
     verify_trace_roundtrip(sch=sch, mod=elementwise)
 
+
+def test_reindex_cache_write_fail_not_match():
     sch = tir.Schedule(elementwise, debug_mask="all")
     with pytest.raises(tvm.tir.ScheduleError):
         sch.reindex_cache_write(
@@ -1421,5 +1470,11 @@ def test_reindex_cache_write():
         )
 
 
+def test_reindex_cache_write_fail_not_single_point():
+    sch = tir.Schedule(access_under_scope, debug_mask="all")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reindex_cache_write("scope", 0, "shared", lambda i, j: (i, j))
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From db1d7928efc273a84d0b272d69bc9b09adf453a7 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Fri, 3 Mar 2023 01:23:03 -0800
Subject: [PATCH 10/13] remove unused args

---
 python/tvm/tir/schedule/schedule.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 7431afe10bc9..e659983d89e1 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1304,7 +1304,6 @@ def reindex_cache_write(
         write_buffer_index: int,
         storage_scope: str,
         index_map: Union[Callable, IndexMap],
-        consumer_blocks: Optional[List[Union[BlockRV, str]]] = None,
     ) -> BlockRV:
         r"""Create a block that reads a buffer region into a write cache using customized
         indices specified by index map. The write region of the buffer must be a single point.
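
For reference, a minimal usage sketch of the API this series converges on (not one of the
diffs above; the 128x128 elementwise workload is a stand-in modeled on the tests in
test_tir_schedule_cache_read_write.py):

    import tvm
    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def elementwise(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, (128, 128))
        B = T.match_buffer(b, (128, 128))
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0

    sch = tir.Schedule(elementwise)
    # The write of B is a single point (B[vi, vj]), so it can be routed through a
    # shared-memory cache whose layout is given by the transposed index map (j, i);
    # the remapping is applied on the consumer side of the new cache stage.
    sch.reindex_cache_write("B", 0, "shared", lambda i, j: (j, i))
    print(sch.mod["main"].script())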
From afb7e32826310e69e1aba2218dfa4d9c95011671 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Fri, 3 Mar 2023 04:40:37 -0800
Subject: [PATCH 11/13] lint

---
 src/tir/schedule/traced_schedule.cc | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index e454b82c5315..f19668bf83d2 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -310,9 +310,10 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer
 }
 
 BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
-                                             const String& storage_scope, const IndexMap& index_map) {
-  BlockRV result = ConcreteScheduleNode::ReindexCacheRead(
-      block_rv, read_buffer_index, storage_scope, index_map);
+                                             const String& storage_scope,
+                                             const IndexMap& index_map) {
+  BlockRV result =
+      ConcreteScheduleNode::ReindexCacheRead(block_rv, read_buffer_index, storage_scope, index_map);
   static const InstructionKind& kind = InstructionKind::Get("ReindexCacheRead");
 
   trace_->Append(
@@ -327,8 +328,8 @@ BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_b
 BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
                                               const String& storage_scope,
                                               const IndexMap& index_map) {
-  BlockRV result = ConcreteScheduleNode::ReindexCacheWrite(
-      block_rv, write_buffer_index, storage_scope, index_map);
+  BlockRV result = ConcreteScheduleNode::ReindexCacheWrite(block_rv, write_buffer_index,
+                                                           storage_scope, index_map);
   static const InstructionKind& kind = InstructionKind::Get("ReindexCacheWrite");
 
   trace_->Append(

From 3f6900113f053a429f914037a953b8b722ab6582 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Tue, 7 Mar 2023 07:15:48 -0800
Subject: [PATCH 12/13] use VarDefUseAnalyzer

---
 .../schedule/primitive/cache_read_write.cc | 30 ++++++++-----------
 1 file changed, 12 insertions(+), 18 deletions(-)

diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 2185f99bb723..9d714c8eda33 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -19,6 +19,7 @@
 #include <unordered_map>
 
+#include "../../analysis/var_use_def_analysis.h"
 #include "../utils.h"
 
 namespace tvm {
 namespace tir {
@@ -1205,7 +1206,9 @@ Buffer CreateReindexBuffer(const Buffer& buffer, const Array<IterVar>& block_ite
   return Buffer(new_buffer);
 }
 
-/*! \brief The schedule error that the target is not a leaf block. */
+/*!
+ * \brief The schedule error that the target is not a leaf block.
+ */
 class NotLeafBlockError : public ScheduleError {
  public:
   NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {}
@@ -1604,16 +1607,6 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   return result_block_sref;
 }
 
-/*! \brief A visitor that collects variables appeared in expressions, stored in `touched` field.*/
-class VarCollector : public ExprVisitor {
- public:
-  VarCollector() {}
-  std::unordered_set<const VarNode*> touched;
-
- private:
-  void VisitExpr_(const VarNode* op) final { touched.insert(op); }
-};
-
 Array<StmtSRef> GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& top_sref) {
   std::vector<StmtSRef> result;
   for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance<ForNode>();
@@ -1624,7 +1617,8 @@ Array<StmtSRef> GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& t
   return {result.rbegin(), result.rend()};
 }
 
-/*! \brief The schedule error that block iter vars appears in old buffer and new
- * allocated cache buffer does not match.
- */
+/*!
+ * \brief The schedule error raised when the block iter vars appearing in the old buffer
+ * indices and in the newly allocated cache buffer indices do not match.
+ */
 class ReindexCacheReadWriteNotMatchError : public ScheduleError {
@@ -1692,7 +1686,7 @@ void CollectReindexCacheStageInfoAndCreateBuffer(
   info->indices = new_indices;
 
   // Step 5. Update CacheTouchedInfo
-  VarCollector collector_old;
+  VarUseDefAnalyzer collector_old;
   Array<PrimExpr> old_indices;
   for (const Range& range : cache_region->region) {
     collector_old(range->min);
@@ -1701,17 +1695,17 @@ void CollectReindexCacheStageInfoAndCreateBuffer(
 
   arith::Analyzer analyzer;
 
-  VarCollector collector_new;
+  VarUseDefAnalyzer collector_new;
   for (const PrimExpr& idx : new_indices) {
     collector_new(idx);
   }
 
-  VarCollector collector_iter_values;
+  VarUseDefAnalyzer collector_iter_values;
   for (size_t i = 0; i < block->iter_vars.size(); ++i) {
     const IterVar& block_iter_var = block->iter_vars[i];
     const PrimExpr& block_iter_value = realize->iter_values[i];
-    bool appears_in_new = collector_new.touched.count(block_iter_var->var.get());
-    bool appears_in_old = collector_old.touched.count(block_iter_var->var.get());
+    bool appears_in_new = collector_new.use_count_.count(block_iter_var->var.get());
+    bool appears_in_old = collector_old.use_count_.count(block_iter_var->var.get());
     if (appears_in_new != appears_in_old) {
       throw ReindexCacheReadWriteNotMatchError(mod, block, block_iter_var->var, old_indices,
                                                new_indices, is_cache_read, appears_in_old);
@@ -1725,7 +1719,7 @@ void CollectReindexCacheStageInfoAndCreateBuffer(
   for (const StmtSRef& loop_sref : GetLoopsUnderScope(block_sref, info->loc_sref)) {
     const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
-    if (collector_iter_values.touched.count(loop->loop_var.get())) {
+    if (collector_iter_values.use_count_.count(loop->loop_var.get())) {
       info->loop_vars.push_back(loop->loop_var);
       info->loop_ranges.push_back(Range::FromMinExtent(loop->min, loop->extent));
     }

From c0c7a456922a529bf72c9fb2039b3f1a2de7f284 Mon Sep 17 00:00:00 2001
From: Zihao
Date: Tue, 7 Mar 2023 07:23:34 -0800
Subject: [PATCH 13/13] fix

---
 src/tir/schedule/primitive/cache_read_write.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 9d714c8eda33..39e915ba961a 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -1686,7 +1686,7 @@ void CollectReindexCacheStageInfoAndCreateBuffer(
   info->indices = new_indices;
 
   // Step 5. Update CacheTouchedInfo
-  VarUseDefAnalyzer collector_old;
+  VarUseDefAnalyzer collector_old(/*defined_vars=*/{});
   Array<PrimExpr> old_indices;
   for (const Range& range : cache_region->region) {
     collector_old(range->min);
@@ -1695,12 +1695,12 @@ void CollectReindexCacheStageInfoAndCreateBuffer(
 
   arith::Analyzer analyzer;
 
-  VarUseDefAnalyzer collector_new;
+  VarUseDefAnalyzer collector_new(/*defined_vars=*/{});
   for (const PrimExpr& idx : new_indices) {
     collector_new(idx);
   }
 
-  VarUseDefAnalyzer collector_iter_values;
+  VarUseDefAnalyzer collector_iter_values(/*defined_vars=*/{});
   for (size_t i = 0; i < block->iter_vars.size(); ++i) {
     const IterVar& block_iter_var = block->iter_vars[i];
     const PrimExpr& block_iter_value = realize->iter_values[i];