diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 7f2bdf6b4ebb..22febfdfedec 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -405,6 +405,34 @@ class ScheduleNode : public runtime::Object { virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) = 0; + /*! + * \brief Create a block that reads a buffer region into a read cache. It requires: + * 1) There is at most one block who writes the buffer in the scope. + * 2) The scope block have stage-pipeline property. + * Compared to cache read, the indices to access allocated cache buffer is customized by user. + * \param block_rv The consumer block of the target buffer. + * \param read_buffer_index The index of the buffer in block's read region. + * \param storage_scope The target storage scope. + * \param index_map User defined indices to access allocated cache buffer, maps from block iter + * vars. + * \return The cache stage block. + */ + virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, const IndexMap& index_map) = 0; + /*! + * \brief Create a block that writes a buffer region into a write cache. It requires: + * 1) There is only one block who writes the target buffer. + * 2) The scope block have stage-pipeline property. + * Compared to cache write, the indices to access allocated cache buffer is customized by user. + * \param block_rv The producer of the buffer + * \param write_buffer_index The index of the buffer in block's write region + * \param storage_scope The target storage scope + * \param index_map User defined indices to access allocated cache buffer, maps from block iter + * vars. + * \return The cache stage block. + */ + virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, const IndexMap& index_map) = 0; /*! 
* \brief Create 2 blocks that read&write a buffer region into a read/write cache. * It requires the the target block both read & write the target buffer. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b63353bcb382..896e2fc48e72 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1115,7 +1115,7 @@ def cache_write( block: Union[BlockRV, str], write_buffer_index: Union[int, str, Buffer], storage_scope: str, - consumer_blocks=None, + consumer_blocks: Optional[List[Union[BlockRV, str]]] = None, ) -> BlockRV: """Create a block that reads a buffer region into a write cache. It requires: @@ -1203,6 +1203,197 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: self, block, write_buffer_index, storage_scope, consumer_blocks ) + @type_checked + def reindex_cache_read( + self, + block: Union[BlockRV, str], + read_buffer_index: int, + storage_scope: str, + index_map: Union[IndexMap, Callable], + ) -> BlockRV: + """Create a block that reads a buffer region into a read cache using customized + indices specified by index map. The read region of the buffer must be a single point. + + The cache stage block follows the original order of loops and block itervars in the block. + If a block itervar does not appear in the buffer access region, it and its corresponding + loop variables will be omitted. User can then use `transform_block_layout` primitive to + reorder the block itervars and surrounding loops of the cache read/write block. + + Unlike `cache_read`, `reindex_cache_read` only supports single consumer, please use + `cache_read` when there are multiple consumers. + + Parameters + ---------- + block : BlockRV + The consumer block of the target buffer. + read_buffer_index: int + The index of the buffer in block's read region. + storage_scope: str + The target storage scope. 
+ index_map: Union[IndexMap, Callable] + User defined indices to access allocated cache buffer, maps from block iter vars. + + Returns + ------- + cached_block : BlockRV + The block of the cache stage + + Examples + -------- + Before reindex_cache_read, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_reindex_cache_read(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and reindex_cache_read: + + .. code-block:: python + + sch = tir.Schedule(before_reindex_cache_read) + block_b = sch.get_block("B") + sch.reindex_cache_read(block_b, 0, "local", lambda vi, vj: (vj, vi)) + print(sch.mod["main"].script()) + + After applying reindex_cache_read, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_reindex_cache_read(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + A_local = T.alloc_buffer((128, 128), scope="local") + for i, j in T.grid(128, 128): + with T.block("A_local"): + vi, vj = T.axis.remap("SS", [i, j]) + A_local[vj, vi] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A_local[vj, vi] * 2.0 + + See Also + -------- + reindex_cache_write + transform_block_layout + transform_layout + cache_read + reindex + """ + # Convert any string block names into Block RVs. 
+ block = self._normalize_block_arg(block) + + if callable(index_map): + index_map = IndexMap.from_func(index_map) + return _ffi_api.ScheduleReindexCacheRead( # type: ignore # pylint: disable=no-member + self, block, read_buffer_index, storage_scope, index_map + ) + + @type_checked + def reindex_cache_write( + self, + block: Union[BlockRV, str], + write_buffer_index: int, + storage_scope: str, + index_map: Union[Callable, IndexMap], + ) -> BlockRV: + r"""Create a block that reads a buffer region into a write cache using customized + indices specified by index map. The write region of the buffer must be a single point. + + The cache stage block follows the original order of loops and block itervars in the block. + If a block itervar does not appear in the buffer access region, it and its corresponding + loop variables will be omitted. User can then use `transform_block_layout` primitive to + reorder the block itervars and surrounding loops of the cache read/write block. + + Unlike `cache_write`, `reindex_cache_write` only supports single consumer, please use + `cache_write` when there are multiple consumers. + + Parameters + ---------- + block : Union[BlockRV, str] + The producer block of the target buffer. + write_buffer_index: int + The index of the buffer in block's write region. + storage_scope: str + The target storage scope. + index_map: Union[Callable, IndexMap] + User defined indices to access allocated cache buffer, maps from block iter vars. + consumer_blocks: Optional[List[Union[BlockRV, str]]] + An optional list of consumers that should read directly from the cache. + If not specified, all consumers will read from the original buffer. + + Returns + ------- + cached_block : BlockRV + The block of the cache stage + + Examples + -------- + Before reindex_cache_write, in TensorIR, the IR is: + + .. 
code-block:: python + + @T.prim_func + def before_reindex_cache_write(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and reindex_cache_write: + + .. code-block:: python + + sch = tir.Schedule(before_reindex_cache_write) + block_b = sch.get_block("B") + sch.reindex_cache_write(block_b, 0, "local", lambda vi, vj: (vi // 2, vi % 2, vj)) + print(sch.mod["main"].script()) + + After applying reindex_cache_write, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_reindex_cache_write(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + B_local = T.alloc_buffer((64, 2, 128), scope="local") + for i, j in T.grid(128, 128): + with T.block("B_local"): + vi, vj = T.axis.remap("SS", [i, j]) + B_local[vi // 2, vi % 2, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = B_local[vi // 2, vi % 2, vj] + + See Also + -------- + reindex_cache_read + transform_block_layout + transform_layout + cache_write + reindex + """ + # Convert any string block names into Block RVs. + block = self._normalize_block_arg(block) + + if callable(index_map): + index_map = IndexMap.from_func(index_map) + return _ffi_api.ScheduleReindexCacheWrite( # type: ignore # pylint: disable=no-member + self, block, write_buffer_index, storage_scope, index_map + ) + + @type_checked + def cache_inplace( + self, @@ -1439,7 +1630,7 @@ def before_reindex( vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vj, vi] * 2.0 - Create the schedule and do transform_layout: + Create the schedule and do reindex: .. 
code-block:: python diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 8af39b24fdb8..5a9dab4854bd 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -568,6 +568,30 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff return CreateRV(result); } +BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, + const IndexMap& index_map) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::ReindexCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope, + index_map); + TVM_TIR_SCHEDULE_END("reindex-cache-read", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +BlockRV ConcreteScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, + const IndexMap& index_map) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::ReindexCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, + storage_scope, index_map); + TVM_TIR_SCHEDULE_END("reindex-cache-write", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + Array ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) { Array results; diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 41168fb016f3..82ac9f913374 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -116,6 +116,10 @@ class ConcreteScheduleNode : public ScheduleNode { const Array consumer_blocks = {}) override; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) override; + BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& 
storage_scope, const IndexMap& index_map) override; + BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, const IndexMap& index_map) override; Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) override; Array CacheIndex(const BlockRV& block_rv, const String& storage_scope, diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 0b7a4f6280db..563864229a26 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -269,6 +269,39 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope, const Array consumer_blocks = {}); +/*! + * \brief Create a block that reads a buffer region into a read cache. It requires: + * 1) There is at most one block who writes the buffer in the scope. + * 2) The scope block have stage-pipeline property. + * Compared to cache read, the indices to access allocated cache buffer is customized by user. + * \param self The state of the schedule + * \param block_sref The consumer block of the target buffer. + * \param read_buffer_index The index of the buffer in block's read region. + * \param storage_scope The target storage scope. + * \param index_map User defined indices to access allocated cache buffer, maps from block iter + * vars. + * \return The cache stage block. + */ +TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope, + const IndexMap& index_map); +/*! + * \brief Create a block that writes a buffer region into a write cache. It requires: + * 1) There is only one block that writes the target buffer. + * 2) The scope block have stage-pipeline property. + * Compared to cache write, the indices to access allocated cache buffer is customized by user. 
+ * \param self The state of the schedule + * \param block_sref The producer of the buffer + * \param write_buffer_index The index of the buffer in block's write region + * \param storage_scope The target storage scope + * \param index_map User defined indices to access allocated cache buffer, maps from block iter + * vars. + * \return The cache stage block. + */ +TVM_DLL StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope, + const IndexMap& index_map); + /*! *! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index a2b45d407ddf..39e915ba961a 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -19,6 +19,7 @@ #include +#include "../../analysis/var_use_def_analysis.h" #include "../utils.h" namespace tvm { @@ -94,6 +95,132 @@ Optional GetBufferRegionFromBuffer(const Array& buff return res; } +struct ReindexCacheStageInfo : CacheStageInfo { + /* Indices used to access the allocated cache buffer. */ + Array indices; + /* Touched loop variable related information. */ + Array loop_vars; + Array loop_ranges; + /* Touched block variable related information. */ + Array block_iter_vars; + Array block_iter_values; +}; + +/* \brief The schedule error that accessed buffer region is not a single point for + * reindex_cache_read/write. */ +class NotSinglePointAccess : public ScheduleError { + public: + explicit NotSinglePointAccess(IRModule mod, Block block, BufferRegion cache_region, + bool is_cache_read) + : mod_(std::move(mod)), block_(std::move(block)), cache_region_(cache_region) { + primitive_name_ = is_cache_read ? 
"reindex_cache_read" : "reindex_cache_write"; + } + + String FastErrorString() const final { + return "ScheduleError: The buffer region accessed inside the block is not a single point."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The buffer region " << cache_region_ + << " accessed inside block {0} is not a single point, which violates" + << " the prerequisite of " << primitive_name_ << " primitive."; + return String(os.str()); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion cache_region_; + String primitive_name_; +}; + +/*! + * \brief Create a loop nest that represents reindex cache copy (reindex_cache_read / + * reindex_cache_write) from read buffer to write buffer. + * \param cache_region The cached copy region. + * \param info The cache stage information, which will be updated in the function. + * \param storage_scope The storage scope of the cached buffer (only used in naming here) + * \returns A block indicating the body of the loop nesting. 
+ */ +template +Block MakeReindexCacheStage(const BufferRegion& cache_region, ReindexCacheStageInfo* info, + const String& storage_scope) { + // loop variables + std::vector loop_vars; + // block variables + Array block_vars; + // bindings in block realize + std::vector iter_values; + // Create loop vars and block vars' binding_value + Map var_map; + for (size_t i = 0; i < info->loop_vars.size(); ++i) { + Var original_var = info->loop_vars[i]; + Var loop_var(original_var->name_hint, original_var.dtype()); + var_map.Set(original_var, loop_var); + loop_vars.push_back(loop_var); + } + for (size_t i = 0; i < info->block_iter_vars.size(); ++i) { + IterVar original_block_var = info->block_iter_vars[i]; + PrimExpr original_iter_value = info->block_iter_values[i]; + IterVar block_var = IterVar( + /*dom=*/original_block_var->dom, + /*var=*/Var(original_block_var->var->name_hint, original_block_var->var.dtype()), + /*IterVarType=*/kDataPar); + var_map.Set(original_block_var->var, block_var->var); + block_vars.push_back(block_var); + iter_values.push_back(Substitute(original_iter_value, var_map)); + } + + // block access region for read/write buffers + Region read_access_region, write_access_region; + Array read_access_indices, write_access_indices; + // Compute read/write region and read/write access indices. + Array& old_indices = (is_cache_read) ? read_access_indices : write_access_indices; + Region& old_region = (is_cache_read) ? read_access_region : write_access_region; + for (const Range& range : cache_region->region) { + old_indices.push_back(Substitute(range->min, var_map)); + old_region.push_back(Range::FromMinExtent(old_indices.back(), Integer(1))); + } + Array& new_indices = (is_cache_read) ? write_access_indices : read_access_indices; + Region& new_region = (is_cache_read) ? 
write_access_region : read_access_region; + for (const PrimExpr& idx : info->indices) { + new_indices.push_back(Substitute((idx), var_map)); + new_region.push_back(Range::FromMinExtent(new_indices.back(), Integer(1))); + } + + // Create New Block + Block block( + /*iter_vars*/ std::move(block_vars), + /*reads=*/{BufferRegion(info->read_buffer, read_access_region)}, + /*writes=*/{BufferRegion(info->write_buffer, write_access_region)}, + /*name_hint*/ cache_region->buffer->name + "_" + storage_scope, + /*body=*/ + BufferStore(info->write_buffer, BufferLoad(info->read_buffer, read_access_indices), + write_access_indices), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*buf_doms=*/{}); + // Create Block Realize node + Stmt body = BlockRealize(/*values=*/iter_values, + /*predicate=*/const_true(), + /*block=*/block); + // Create surrounding loops + for (size_t i = loop_vars.size(); i >= 1; --i) { + body = For(/*loop_var=*/loop_vars[i - 1], + /*min=*/info->loop_ranges[i - 1]->min, + /*extent=*/info->loop_ranges[i - 1]->extent, + /*kind=*/ForKind::kSerial, + /*body=*/body); + } + info->cache_stage = std::move(body); + return block; +} + /*! * \brief Create a loop nest that represents cache copy (cache_read / cache_write) from read buffer * to write buffer. @@ -378,9 +505,12 @@ class CacheLocDetector : public StmtVisitor { public: /*! * \brief Detect the insertion position of the cache stage, and write the position into the - * CacheStageInfo \param self The state of the schedule \param block_sref The sref of the unique - * writer block of the buffer being applied cache_read or cache_write \param scope_sref The sref - * of the scope block of the cached block \param info The cache stage info. 
+ * CacheStageInfo + * \param self The state of the schedule + * \param block_sref The sref of the unique writer block of the buffer being applied cache_read or + * cache_write + * \param scope_sref The sref of the scope block of the cached block + * \param info The cache stage info. */ template static void Detect(const ScheduleState& self, const StmtSRef& block_sref, @@ -433,8 +563,9 @@ class CacheLocDetector : public StmtVisitor { * \brief Constructor * \param self The state of the schedule * \param block_sref The sref of the unique writer block of the buffer being applied cache_read or - * cache_write \param scope_sref The sref of the scope block of the cached block \param - * related_blocks Producer blocks for cache_write, or consumer blocks for cache_read + * cache_write + * \param scope_sref The sref of the scope block of the cached block + * \param related_blocks Producer blocks for cache_write, or consumer blocks for cache_read */ CacheLocDetector(const ScheduleState self, const StmtSRef& block_sref, const StmtSRef& scope_sref, const std::vector& related_blocks) @@ -525,9 +656,11 @@ class CacheInplaceLocDetector : public StmtVisitor { public: /*! * \brief Detect the insertion position of the cache stage, and write the position into the - * CacheStageInfo \param self The state of the schedule \param block_sref The sref of the unique - * block of the buffer being applied cache_inplace \param scope_sref The sref - * of the scope block of the cached block \param info The cache stage info. + * CacheStageInfo + * \param self The state of the schedule + * \param block_sref The sref of the unique block of the buffer being applied cache_inplace + * \param scope_sref The sref of the scope block of the cached block + * \param info The cache stage info. 
*/ static void Detect(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_sref, CacheStageInfo* info) { @@ -606,6 +739,8 @@ class CacheInplaceLocDetector : public StmtVisitor { int loc_pos_{-1}; }; +class ReindexCacheReadRewriter; + /*! \brief Mutator for CacheRead. */ class CacheReadRewriter : public StmtExprMutator { public: @@ -622,7 +757,14 @@ class CacheReadRewriter : public StmtExprMutator { private: explicit CacheReadRewriter(const StmtSRef& scope_sref, CacheStageInfo* info) - : scope_sref_(scope_sref), info_(info) {} + : scope_sref_(scope_sref), info_(info) { + update_access_regions = [&](Array regions) { + return ReplaceBuffer(std::move(regions), info_->read_buffer, info_->write_buffer); + }; + update_match_buffers = [&](Array match_buffers) { + return ReplaceBuffer(std::move(match_buffers), info_->read_buffer, info_->write_buffer); + }; + } Stmt VisitStmt_(const ForNode* loop) final { Stmt stmt = StmtMutator::VisitStmt_(loop); @@ -636,7 +778,7 @@ class CacheReadRewriter : public StmtExprMutator { return stmt; } - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const BlockNode* block) override { Block old_stmt = GetRef(block); // Check if this block is one of the specified consumers. // If no consumer blocks are specified, all blocks should be considered consumers. @@ -678,10 +820,8 @@ class CacheReadRewriter : public StmtExprMutator { // Otherwise, update read regions and match_buffers // Only make this change if the block is one of the specified consumers. 
if (is_consumer) { - Array reads = - ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer); - Array match_buffers = - ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer); + Array reads = update_access_regions(block->reads); + Array match_buffers = update_match_buffers(block->match_buffers); if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { ObjectPtr n = make_object(*stmt.as()); n->reads = std::move(reads); @@ -694,7 +834,7 @@ class CacheReadRewriter : public StmtExprMutator { return std::move(stmt); } - PrimExpr VisitExpr_(const BufferLoadNode* load) final { + PrimExpr VisitExpr_(const BufferLoadNode* load) override { if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { ObjectPtr n = make_object(*load); n->buffer = info_->write_buffer; @@ -721,8 +861,82 @@ class CacheReadRewriter : public StmtExprMutator { CacheStageInfo* info_; /*! \brief Whether the most recently visited block is a specified consumer. */ bool current_block_consumes; + /*! \brief function to update read/write region of block being cache read.*/ + std::function(Array)> update_access_regions; + /*! \brief function to update match buffers of block being cache read.*/ + std::function(Array)> update_match_buffers; + + friend ReindexCacheReadRewriter; }; +/*! \brief Mutator for ReindexCacheRead. */ +class ReindexCacheReadRewriter : public CacheReadRewriter { + public: + /*! + * \brief Rewrite the AST and add a cache_read stage with the information provided. + * \param scope_sref The parent scope of this mutation. + * \param info The cache stage information. + * \return The new AST rooting at the original parent scope. 
+ */ + static Stmt Rewrite(const StmtSRef& scope_sref, ReindexCacheStageInfo* info) { + ReindexCacheReadRewriter rewriter(scope_sref, info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit ReindexCacheReadRewriter(const StmtSRef& scope_sref, ReindexCacheStageInfo* info) + : CacheReadRewriter(scope_sref, info) { + new_indices_ = info->indices; + update_access_regions = [&](Array reads) { + Array new_reads; + for (const BufferRegion& buf_region : reads) { + if (buf_region->buffer.same_as(info_->read_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + new_reads.push_back(BufferRegion(info_->write_buffer, region)); + } else { + new_reads.push_back(buf_region); + } + } + return new_reads; + }; + update_match_buffers = [&](const Array match_buffers) { + Array new_match_buffers; + for (const MatchBufferRegion& match_buffer_region : match_buffers) { + BufferRegion source = match_buffer_region->source; + if (source->buffer.same_as(info_->read_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + new_match_buffers.push_back(MatchBufferRegion(match_buffer_region->buffer, + BufferRegion(info_->write_buffer, region))); + } else { + new_match_buffers.push_back(match_buffer_region); + } + } + return new_match_buffers; + }; + } + + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) { + ObjectPtr n = make_object(*load); + n->buffer = info_->write_buffer; + n->indices = new_indices_; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + /*! \brief The indices to use for new buffer. */ + Array new_indices_; +}; + +class ReindexCacheWriteRewriter; + /*! 
\brief Mutator for CacheWrite */ class CacheWriteRewriter : public StmtExprMutator { public: @@ -742,7 +956,14 @@ class CacheWriteRewriter : public StmtExprMutator { private: explicit CacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, CacheStageInfo* info) - : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) {} + : scope_sref_(scope_sref), writer_block_sref_(writer_block_sref), info_(info) { + update_access_regions = [&](Array regions) { + return ReplaceBuffer(regions, info_->write_buffer, info_->read_buffer); + }; + update_match_buffers = [&](Array match_buffers) { + return ReplaceBuffer(match_buffers, info_->write_buffer, info_->read_buffer); + }; + } Stmt VisitStmt_(const ForNode* loop) final { Stmt stmt = StmtMutator::VisitStmt_(loop); @@ -756,7 +977,7 @@ class CacheWriteRewriter : public StmtExprMutator { return stmt; } - Stmt VisitStmt_(const BlockNode* block) final { + Stmt VisitStmt_(const BlockNode* block) override { Block old_stmt = GetRef(block); // Check if this block is one of the specified cache consumers. 
@@ -765,12 +986,13 @@ class CacheWriteRewriter : public StmtExprMutator { const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref); Block consumer_block = GetRef(consumer_node); if (old_stmt.same_as(consumer_block)) { - Array reads = - ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer); - Array match_buffers = - ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer); - if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { + Array writes = update_access_regions(block->writes); + Array reads = update_access_regions(block->reads); + Array match_buffers = update_match_buffers(block->match_buffers); + if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || + !match_buffers.same_as(block->match_buffers)) { auto n = CopyOnWrite(block); + n->writes = std::move(writes); n->reads = std::move(reads); n->match_buffers = std::move(match_buffers); n->body = VisitStmt(block->body); @@ -809,10 +1031,9 @@ class CacheWriteRewriter : public StmtExprMutator { } } else { // Since cache_write changes the block, we need to update the buffer it writes - auto writes = ReplaceBuffer(block->writes, info_->write_buffer, info_->read_buffer); - auto reads = ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer); - auto match_buffers = - ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer); + auto writes = update_access_regions(block->writes); + auto reads = update_access_regions(block->reads); + auto match_buffers = update_match_buffers(block->match_buffers); if (!writes.same_as(block->writes) || !reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) { ObjectPtr n = make_object(*stmt.as()); @@ -826,7 +1047,7 @@ class CacheWriteRewriter : public StmtExprMutator { return std::move(stmt); } - Stmt VisitStmt_(const BufferStoreNode* store) final { + Stmt VisitStmt_(const BufferStoreNode* store) override { BufferStore stmt = 
Downcast(StmtMutator::VisitStmt_(store)); if (stmt->buffer.same_as(info_->write_buffer)) { auto n = CopyOnWrite(stmt.get()); @@ -837,7 +1058,7 @@ class CacheWriteRewriter : public StmtExprMutator { } } - PrimExpr VisitExpr_(const BufferLoadNode* load) final { + PrimExpr VisitExpr_(const BufferLoadNode* load) override { if (load->buffer.same_as(info_->write_buffer)) { ObjectPtr n = make_object(*load); n->buffer = info_->read_buffer; @@ -870,6 +1091,93 @@ class CacheWriteRewriter : public StmtExprMutator { CacheStageInfo* info_; /*! \brief Whether the current node is under the given block. */ bool under_writer_block_{false}; + /*! \brief function to update read/write region of block being cache write.*/ + std::function(Array)> update_access_regions; + /*! \brief function to update match buffers of block being cache write.*/ + std::function(Array)> update_match_buffers; + + friend ReindexCacheWriteRewriter; +}; + +/*! \brief Mutator for ReindexCacheWrite. */ +class ReindexCacheWriteRewriter : public CacheWriteRewriter { + public: + /*! + * \brief Rewrite the AST and add a cache_write stage with the information provided. + * \param scope_sref The parent scope of this mutation. + * \param writer_block_sref The only writer block in the scope. + * \param info The cache stage information. + * \return The new AST rooting at the original parent scope. 
+ */ + static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, + ReindexCacheStageInfo* info) { + ReindexCacheWriteRewriter rewriter(scope_sref, writer_block_sref, info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit ReindexCacheWriteRewriter(const StmtSRef& scope_sref, const StmtSRef& writer_block_sref, + ReindexCacheStageInfo* info) + : CacheWriteRewriter(scope_sref, writer_block_sref, info) { + new_indices_ = info->indices; + update_access_regions = [&](Array reads) { + Array new_reads; + for (const BufferRegion& buf_region : reads) { + if (buf_region->buffer.same_as(info_->write_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + new_reads.push_back(BufferRegion(info_->read_buffer, region)); + } else { + new_reads.push_back(buf_region); + } + } + return new_reads; + }; + update_match_buffers = [&](const Array match_buffers) { + Array new_match_buffers; + for (const MatchBufferRegion& match_buffer_region : match_buffers) { + BufferRegion source = match_buffer_region->source; + if (source->buffer.same_as(info_->write_buffer)) { + Array region; + for (const PrimExpr index : new_indices_) { + region.push_back(Range::FromMinExtent(index, Integer(1))); + } + new_match_buffers.push_back(MatchBufferRegion(match_buffer_region->buffer, + BufferRegion(info_->read_buffer, region))); + } else { + new_match_buffers.push_back(match_buffer_region); + } + } + return new_match_buffers; + }; + } + + Stmt VisitStmt_(const BufferStoreNode* store) final { + BufferStore stmt = Downcast(StmtMutator::VisitStmt_(store)); + if (stmt->buffer.same_as(info_->write_buffer)) { + auto n = CopyOnWrite(stmt.get()); + n->buffer = info_->read_buffer; + n->indices = new_indices_; + return Stmt(n); + } else { + return std::move(stmt); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.same_as(info_->write_buffer)) { + 
ObjectPtr n = make_object(*load); + n->buffer = info_->read_buffer; + n->indices = new_indices_; + return PrimExpr(n); + } + return ExprMutator::VisitExpr_(load); + } + + /*! \brief The indices to use for new buffer. */ + Array new_indices_; }; /*! @@ -898,7 +1206,9 @@ Buffer CreateReindexBuffer(const Buffer& buffer, const Array& block_ite return Buffer(new_buffer); } -/*! \brief The schedule error that the target is not a leaf block. */ +/*! + * \brief The schedule error that the target is not a leaf block. + */ class NotLeafBlockError : public ScheduleError { public: NotLeafBlockError(IRModule mod, Block block) : mod_(std::move(mod)), block_(std::move(block)) {} @@ -1297,6 +1607,293 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu return result_block_sref; } +Array GetLoopsUnderScope(const StmtSRef& block_sref, const StmtSRef& top_sref) { + std::vector result; + for (StmtSRefNode* parent = block_sref->parent; parent && parent->stmt->IsInstance(); + parent = parent->parent) { + if (parent == top_sref.get()) break; + result.push_back(GetRef(parent)); + } + return {result.rbegin(), result.rend()}; +} + +/*! + * \brief The schedule error that block iter vars appears in old buffer and new + * allocated cache buffer does not match. + */ +class ReindexCacheReadWriteNotMatchError : public ScheduleError { + public: + ReindexCacheReadWriteNotMatchError(IRModule mod, Block block, Var var, + Array old_indices, Array new_indices, + bool is_cache_read, bool appears_in_old) + : mod_(std::move(mod)), block_(std::move(block)), var_(std::move(var)) { + primitive_name_ = is_cache_read ? 
"reindex_cache_read" : "reindex_cache_write"; + if (appears_in_old) { + appears_indices_ = std::move(old_indices); + other_indices_ = std::move(new_indices); + } else { + appears_indices_ = std::move(new_indices); + other_indices_ = std::move(old_indices); + } + } + String FastErrorString() const final { + return "ScheduleError: the block itervars appeared in lhs and rhs of reindex cache stage do " + "not match."; + } + + String DetailRenderTemplate() const final { + std::stringstream s; + s << "Error when applying " << primitive_name_ << " on block {0}, the block itervar " << var_ + << " appears in " << appears_indices_ << ", but not in " << other_indices_ << "."; + return String(s.str()); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + String primitive_name_; + Block block_; + Var var_; + Array appears_indices_; + Array other_indices_; +}; + +/*! + * \brief Update ReindexCacheStageInfo and create new cache buffer, used in + * both ReindexCacheRead and ReindexCacheWrite. + * \param info Pointer to ReindexCacheStageInfo + * \param mod The IRModule. + * \param block_sref The StmtSRef to the block we are working on. + * \param storage_scope The storage scope of cache buffer (e.g. "shared"/"local"). + * \param index_map The user defined indices. + * \param blok The block we are working on. + * \param realize The BlockRealize this block belongs to. + * \param old_buffer The buffer whose buffer access need to be rewriten. + * \param cache_region The old buffer access region. 
+ */ +template +void CollectReindexCacheStageInfoAndCreateBuffer( + ReindexCacheStageInfo* info, const IRModule& mod, const StmtSRef& block_sref, + const String& storage_scope, const IndexMap& index_map, const Block& block, + const BlockRealize& realize, const Buffer& old_buffer, const BufferRegion& cache_region) { + Array block_iter_vars, block_shape; + for (const IterVar& iter_var : block->iter_vars) { + block_iter_vars.push_back(iter_var); + block_shape.push_back(iter_var->dom->extent); + } + Array new_indices = index_map->MapIndices(block_iter_vars); + Array new_shape = index_map->MapShape(block_shape); + info->indices = new_indices; + + // Step 5. Update CacheTouchedInfo + VarUseDefAnalyzer collector_old(/*defined_vars=*/{}); + Array old_indices; + for (const Range& range : cache_region->region) { + collector_old(range->min); + old_indices.push_back(range->min); + } + + arith::Analyzer analyzer; + + VarUseDefAnalyzer collector_new(/*defined_vars=*/{}); + for (const PrimExpr& idx : new_indices) { + collector_new(idx); + } + + VarUseDefAnalyzer collector_iter_values(/*defined_vars=*/{}); + for (size_t i = 0; i < block->iter_vars.size(); ++i) { + const IterVar& block_iter_var = block->iter_vars[i]; + const PrimExpr& block_iter_value = realize->iter_values[i]; + bool appears_in_new = collector_new.use_count_.count(block_iter_var->var.get()); + bool appears_in_old = collector_old.use_count_.count(block_iter_var->var.get()); + if (appears_in_new != appears_in_old) { + throw ReindexCacheReadWriteNotMatchError(mod, block, block_iter_var->var, old_indices, + new_indices, is_cache_read, appears_in_old); + } + if (appears_in_new) { + info->block_iter_vars.push_back(block_iter_var); + info->block_iter_values.push_back(block_iter_value); + collector_iter_values(block_iter_value); + } + } + + for (const StmtSRef& loop_sref : GetLoopsUnderScope(block_sref, info->loc_sref)) { + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); + if 
(collector_iter_values.use_count_.count(loop->loop_var.get())) { + info->loop_vars.push_back(loop->loop_var); + info->loop_ranges.push_back(Range::FromMinExtent(loop->min, loop->extent)); + } + } + + // Create new buffer + ObjectPtr new_buffer = make_object(*old_buffer.get()); + ObjectPtr new_var = make_object(*old_buffer->data.get()); + const auto* ptr_type = TVM_TYPE_AS(old_buffer->data->type_annotation, PointerTypeNode); + new_var->type_annotation = PointerType(ptr_type->element_type, storage_scope); + new_buffer->data = Var(new_var->name_hint + "_" + storage_scope, new_var->type_annotation); + new_buffer->name = old_buffer->name + "_" + storage_scope; + new_buffer->shape = new_shape; + + if (is_cache_read) { + info->write_buffer = Buffer(new_buffer); + info->alloc = info->write_buffer; + } else { + info->read_buffer = Buffer(new_buffer); + info->alloc = info->read_buffer; + } +} + +/*! \brief Check whether given cache_region is a single point access. */ +template +void CheckSinglePoint(ScheduleState self, const Block& block, const BufferRegion& cache_region) { + bool single_point = true; + for (const Range& range : cache_region->region) { + const auto* ext_int = range->extent.as(); + if (!ext_int || ext_int->value != 1) { + single_point = false; + } + } + if (!single_point) { + throw NotSinglePointAccess(self->mod, block, cache_region, is_cache_read); + } +} + +StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, + const String& storage_scope, const IndexMap& index_map) { + /*! + * Check: + * - The index is in the array of block reading region + * - There is at most one block who write the buffer in the scope + * + * Mutate: + * - Allocate new cache buffer under the current scope. + * - Find the lowest ancestor of the block and ANY ONE of the consumers blocks. + * - Copy the buffer with the consumed region. + */ + + // Step 0. Check the input storage scope. + CheckStorageScope(self, storage_scope); + + // Step 1. 
Check index, getting the target buffer and the parent scope + Block block = GetRef(TVM_SREF_TO_BLOCK(block_sref)); + BlockRealize realize = GetBlockRealize(self, block_sref); + Buffer read_buffer = GetNthAccessBuffer(self, block, read_buffer_index, BufferIndexType::kRead); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + + // Step 2. Create CacheStageInfo + ReindexCacheStageInfo info; + info.read_buffer = read_buffer; + info.consumer_blocks.insert(block_sref); + + // Step 3. Update cache stage info. + Optional maybe_region = GetBufferRegionFromBuffer(block->reads, read_buffer); + ICHECK(maybe_region.defined()) << read_buffer + << " should appear in the block's read region: " << block->reads; + BufferRegion cache_region = maybe_region.value(); + if (Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) { + // Case 1. The buffer is written inside the block. + StmtSRef write_block_sref = _write_block_sref.value(); + // Find the producing region + StmtSRef parent_sref = GetRef(write_block_sref->parent); + // Detect insert position + CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info); + } else { + // Case 2. The buffer is the input block for the scope. + info.loc_sref = scope_sref; + info.loc_pos = 0; + } + + // Step 4. Check whether cache region is a single point. + CheckSinglePoint(self, block, cache_region); + + // Step 5. Collect ReindexCacheStageInfo and create new buffer. + CollectReindexCacheStageInfoAndCreateBuffer( + &info, self->mod, block_sref, storage_scope, index_map, block, realize, read_buffer, + cache_region); + + // Step 6. Making new cache stage block and rewrite readers. + Block cache_read_stage = + MakeReindexCacheStage(/*cache_region=*/cache_region, + /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = ReindexCacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); + + // Step 7. Replacing and updating flags. 
+ self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + +StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, + const String& storage_scope, const IndexMap& index_map) { + /*! + * Check: + * - The index is in the array of block writing region + * - There is only one block who writes the buffer in the scope + * + * Mutate: + * - Allocate new cache buffer under the current scope. + * - Find the lowest ancestor of the block and ANY ONE of the producer blocks. + * - Copy the buffer with the consumed region. + */ + + // Step 0. Check the input storage scope. + CheckStorageScope(self, storage_scope); + + // Step 1. Checking index, getting the target buffer and the parent scope + Block block = GetRef(TVM_SREF_TO_BLOCK(block_sref)); + BlockRealize realize = GetBlockRealize(self, block_sref); + Buffer write_buffer = + GetNthAccessBuffer(self, block, write_buffer_index, BufferIndexType::kWrite); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + + // Step 2. Creating CacheStageInfo + ReindexCacheStageInfo info; + info.write_buffer = write_buffer; + info.consumer_blocks.insert(block_sref); + + // Step 3. Check the only writer block. + ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); + + // Step 4.
Find the producing region and insert position + Optional maybe_region = GetBufferRegionFromBuffer(block->writes, write_buffer); + ICHECK(maybe_region.defined()) << write_buffer << " should appear in the block's write region"; + StmtSRef parent_sref = GetRef(block_sref->parent); + // Detect insert position + CacheLocDetector::Detect(self, block_sref, scope_sref, &info); + BufferRegion cache_region = maybe_region.value(); + + CollectReindexCacheStageInfoAndCreateBuffer( + &info, self->mod, block_sref, storage_scope, index_map, block, realize, write_buffer, + cache_region); + + // Step 5. Check whether cache region is a single point. + CheckSinglePoint(self, block, cache_region); + + // Step 6. Making new cache stage block and rewrite readers. + Block cache_write_stage = + MakeReindexCacheStage(/*cache_region=*/cache_region, + /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = ReindexCacheWriteRewriter::Rewrite( + /*scope_sref=*/scope_sref, + /*writer_block_sref=*/block_sref, /*info=*/&info); + + // Step 7. Replacing and updating flags. + self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(cache_write_stage.get()); + BlockInfo& block_info = self->block_info[result_block_sref]; + block_info.affine_binding = CalculateAffineFlag(self, result_block_sref); + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + return result_block_sref; +} + /*! \brief The schedule error that the target block doesn't both read&write target buffer. 
*/ class NotReadWriteError : public ScheduleError { public: @@ -1606,9 +2203,70 @@ struct ReIndexTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct ReindexCacheReadTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReindexCacheRead"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, IndexMap index_map, + Integer read_buffer_index, String storage_scope) { + return sch->ReindexCacheRead(block, read_buffer_index->value, storage_scope, index_map); + } + + static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, + Integer read_buffer_index, String storage_scope) { + PythonAPICall py("reindex_cache_read"); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.Input("index_map", index_map->ToPythonString()); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct ReindexCacheWriteTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReindexCacheWrite"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, IndexMap index_map, + Integer write_buffer_index, String storage_scope) { + return sch->ReindexCacheWrite(block, write_buffer_index->value, storage_scope, index_map); + } + + static String UnpackedAsPython(Array outputs, String block, IndexMap index_map, + Integer write_buffer_index, String storage_scope) { + PythonAPICall py("reindex_cache_write"); + py.Input("block", block); + py.Input("write_buffer_index", 
write_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.Input("index_map", index_map->ToPythonString()); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheInplaceTraits); TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReindexCacheReadTraits); +TVM_REGISTER_INST_KIND_TRAITS(ReindexCacheWriteTraits); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 4177d916486b..cb8b5a1d7787 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -179,6 +179,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheRead") + .set_body_method(&ScheduleNode::ReindexCacheRead); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReindexCacheWrite") + .set_body_method(&ScheduleNode::ReindexCacheWrite); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace") .set_body_method(&ScheduleNode::CacheInplace); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex") diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index dba34c2ca3f3..a5cb66a0cb44 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -309,6 +309,38 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } +BlockRV TracedScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, + const IndexMap& index_map) { + BlockRV result = + ConcreteScheduleNode::ReindexCacheRead(block_rv, read_buffer_index, storage_scope, index_map); + + static 
const InstructionKind& kind = InstructionKind::Get("ReindexCacheRead"); + trace_->Append( + /*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv, index_map}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + +BlockRV TracedScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, + const IndexMap& index_map) { + BlockRV result = ConcreteScheduleNode::ReindexCacheWrite(block_rv, write_buffer_index, + storage_scope, index_map); + + static const InstructionKind& kind = InstructionKind::Get("ReindexCacheWrite"); + trace_->Append( + /*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv, index_map}, + /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) { Array result = diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 7bd83855557d..1fcba9806380 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -76,6 +76,10 @@ class TracedScheduleNode : public ConcreteScheduleNode { const Array consumer_blocks = {}) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope, const Array consumer_blocks = {}) final; + BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope, const IndexMap& index_map) final; + BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope, const IndexMap& index_map) final; Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) final; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py 
b/tests/python/unittest/test_tir_schedule_cache_read_write.py index be91505f3d15..cf75768ec0e3 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -59,6 +59,58 @@ def elementwise_shape_int64(a: T.handle, c: T.handle) -> None: C[vi, vj] = B[vi, vj] + 1.0 +@T.prim_func +def elementwise_reindex_cache_read( + A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32") +): + B = T.alloc_buffer((128, 128)) + B_shared = T.alloc_buffer((128, 64, 2), scope="shared") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("B_shared"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(B_shared[vj, vi // 2, vi % 2]) + B_shared[vj, vi // 2, vi % 2] = B[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B_shared[vj, vi // 2, vi % 2]) + T.writes(C[vi, vj]) + C[vi, vj] = B_shared[vj, vi // 2, vi % 2] + T.float32(1) + + +@T.prim_func +def elementwise_reindex_cache_write( + A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32") +): + B = T.alloc_buffer((128, 128)) + B_shared = T.alloc_buffer((128, 128), scope="shared") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B_shared[vj, vi]) + B_shared[vj, vi] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("B_shared"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B_shared[vj, vi]) + T.writes(B[vi, vj]) + B[vi, vj] = B_shared[vj, vi] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + @T.prim_func def func_nested_seq(b: T.handle, c: T.handle) -> None: A = 
T.alloc_buffer((128, 128)) @@ -216,6 +268,39 @@ def func_multi_consumer() -> None: C[vi] = A[vi] +@T.prim_func +def reindex_cache_read_multi_consumer() -> None: + A = T.alloc_buffer((128,)) + B = T.alloc_buffer((128,)) + C = T.alloc_buffer((128,)) + A_shared = T.alloc_buffer((4, 32), scope="shared") + for i in range(8): + for j in range(16): + with T.block("A"): + vi = T.axis.spatial(128, i * 16 + j) + T.reads() + T.writes(A[vi]) + A[vi] = T.float32(1) + for j in range(16): + with T.block("A_shared"): + vi = T.axis.spatial(128, i * 16 + j) + T.reads(A[vi]) + T.writes(A_shared[vi // 32, vi % 32]) + A_shared[vi // 32, vi % 32] = A[vi] + for j in range(16): + with T.block("B"): + vi = T.axis.spatial(128, i * 16 + j) + T.reads(A_shared[vi // 32, vi % 32]) + T.writes(B[vi]) + B[vi] = A_shared[vi // 32, vi % 32] + T.float32(1) + for i in range(128): + with T.block("C"): + vi = T.axis.spatial(128, i) + T.reads(A[vi]) + T.writes(C[vi]) + C[vi] = A[vi] + + @T.prim_func def func_multi_producer() -> None: A = T.alloc_buffer((128)) @@ -1336,5 +1421,60 @@ def test_cache_write_allocate_const(): verify_trace_roundtrip(sch=sch, mod=cache_write_allocate_const) +def test_reindex_cache_read(): + sch = tir.Schedule(elementwise, debug_mask="all") + sch.reindex_cache_read("C", 0, "shared", lambda i, j: (j, i // 2, i % 2)) + tvm.ir.assert_structural_equal(elementwise_reindex_cache_read, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_reindex_cache_read_multi_consumer(): + sch = tir.Schedule(func_multi_consumer) + sch.reindex_cache_read("B", 0, "shared", lambda i: (i // 32, i % 32)) + tvm.ir.assert_structural_equal(reindex_cache_read_multi_consumer, sch.mod["main"]) + # NOTE(zihao): we do not verify trace roundtrip because of in set analysis issues. 
+ + +def test_reindex_cache_read_fail_not_match(): + sch = tir.Schedule(elementwise, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.reindex_cache_read( + "C", + 0, + "shared", + lambda i, j: j * 2, + ) + + +def test_reindex_cache_read_fail_not_single_point(): + sch = tir.Schedule(access_under_scope, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.reindex_cache_read("scope", 0, "shared", lambda i, j: (i, j)) + + +def test_reindex_cache_write(): + sch = tir.Schedule(elementwise, debug_mask="all") + sch.reindex_cache_write("B", 0, "shared", lambda i, j: (j, i)) + tvm.ir.assert_structural_equal(elementwise_reindex_cache_write, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=elementwise) + + +def test_reindex_cache_write_fail_not_match(): + sch = tir.Schedule(elementwise, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.reindex_cache_write( + "B", + 0, + "shared", + lambda i, j: i, + ) + + +def test_reindex_cache_write_fail_not_single_point(): + sch = tir.Schedule(access_under_scope, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.reindex_cache_write("scope", 0, "shared", lambda i, j: (i, j)) + + if __name__ == "__main__": tvm.testing.main()