From e484858c846a197c88580b60dd3c58508490d111 Mon Sep 17 00:00:00 2001 From: Min Chen Date: Thu, 22 Sep 2022 03:35:01 +0000 Subject: [PATCH 1/4] [TIR][Schedule] Add cache_buffer primitive to cache opaque buffer --- include/tvm/tir/schedule/schedule.h | 9 + python/tvm/tir/schedule/schedule.py | 87 +++++++ src/tir/schedule/concrete_schedule.cc | 13 + src/tir/schedule/concrete_schedule.h | 2 + src/tir/schedule/primitive.h | 13 +- .../schedule/primitive/cache_read_write.cc | 237 +++++++++++++++++- src/tir/schedule/schedule.cc | 2 + src/tir/schedule/traced_schedule.cc | 16 ++ src/tir/schedule/traced_schedule.h | 2 + .../test_tir_schedule_cache_read_write.py | 58 +++++ 10 files changed, 434 insertions(+), 5 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 049f063240df..51e211100a29 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -403,6 +403,15 @@ class ScheduleNode : public runtime::Object { */ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) = 0; + /*! + * \brief Create 2 blocks that read&write a buffer region into a read/write cache. + * \param block_rv The block operates on the target buffer. + * \param read_buffer_index The index of the buffer in block's read region. + * \param storage_scope The target storage scope + * \return The reindex stage block. + */ + virtual Array CacheBuffer(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) = 0; /*! * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. 
* The layout of the cache will be the same as by the iterators of the block that reads/writes the diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 27171aca411b..22a452a9f08d 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1189,6 +1189,93 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: self, block, write_buffer_index, storage_scope ) + @type_checked + def cache_buffer( + self, + block: Union[BlockRV, str], + read_buffer_index: Union[int, str, Buffer], + storage_scope: str, + ) -> List[BlockRV]: + """Create blocks that reads & write a buffer region into a cache block. + + Parameters + ---------- + block : Union[BlockRV, str] + The producer block of the target buffer. + + read_buffer_index: int + The index of the buffer in block's read region, the unique + name of a read buffer in the block, or a Buffer object + that is within the blocks read region. + + storage_scope: str + The target storage scope. + + + Returns + ------- + cached_blocks : List[BlockRV] + The blocks of the cache stage, read cache first, write cache second + + Examples + -------- + Before cache_buffer, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_cache_buffer(data_io: T.Buffer[(64), "int32"]): + for i0 in T.serial(1): + with T.block("A"): + T.reads(data_io[:64]) + T.writes(data_io[:64]) + T.evaluate(T.call_extern("call_impl", data_io.data, dtype="")) + + Create the schedule and cache_buffer: + + .. code-block:: python + + sch = tir.Schedule(before_cache_buffer) + block_a = sch.get_block("A") + sch.cache_buffer(block_a, 0, "local") + print(sch.mod["main"].script()) + + After applying cache_buffer, the IR becomes: + + .. 
code-block:: python + + @T.prim_func + def cache_inplace(data_io: T.Buffer[64, "int32"]) -> None: + data_io_local = T.alloc_buffer([64], dtype="int32", scope="local") + for i0 in T.serial(1): + for ax0 in T.serial(64): + with T.block("data_io_local"): + v0 = T.axis.spatial(64, ax0) + T.reads(data_io[v0]) + T.writes(data_io_local[v0]) + data_io_local[v0] = data_io[v0] + with T.block("A"): + T.reads(data_io_local[0 : 64]) + T.writes(data_io_local[0 : 64]) + T.evaluate(T.call_extern("call_impl", data_io_local.data, dtype="")) + for ax0 in T.serial(64): + with T.block("data_io_local"): + v0 = T.axis.spatial(64, ax0) + T.reads(data_io_local[v0]) + T.writes(data_io[v0]) + data_io[v0] = data_io_local[v0] + + """ + block = self._normalize_block_arg(block) + + if not isinstance(read_buffer_index, int): + _, read_buffer_index, _ = self._normalize_buffer_arg( + block, read_buffer_index, required_buffer_type="read" + ) + return _ffi_api.ScheduleCacheBuffer( # type: ignore # pylint: disable=no-member + self, block, read_buffer_index, storage_scope + ) + @type_checked def reindex( self, diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 8cfbadf65012..346c597b38a0 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -561,6 +561,19 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff return CreateRV(result); } +Array ConcreteScheduleNode::CacheBuffer(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) { + Array results; + TVM_TIR_SCHEDULE_BEGIN(); + results = tir::CacheBuffer(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope); + TVM_TIR_SCHEDULE_END("cache-buffer", this->error_render_level_); + this->state_->DebugVerify(); + Array return_blocks; + return_blocks.push_back(CreateRV(results[0])); + return_blocks.push_back(CreateRV(results[1])); + return return_blocks; +} + BlockRV ConcreteScheduleNode::ReIndex(const 
BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) { StmtSRef result{nullptr}; diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 59a9e3752859..c1ddf1131bc3 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -116,6 +116,8 @@ class ConcreteScheduleNode : public ScheduleNode { const Array consumer_blocks = {}) override; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) override; + Array CacheBuffer(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) override; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) override; /******** Schedule: Compute location ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 21388ff132ae..365b2175008d 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -267,6 +267,17 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r */ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope); +/*! + *! + * \brief Create 2 blocks that read&write a buffer region into a read/write cache. + * \param self The state of the schedule + * \param block_sref The block operates on the target buffer. + * \param read_buffer_index The index of the buffer in block's read region. + * \param storage_scope The target storage scope + * \return The reindex stage block. + */ +TVM_DLL Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope); /*! *! * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. 
@@ -275,7 +286,7 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int * 1) There is only one block who reads/writes the target buffer * 2) There is only one buffer load/store of this buffer in the block * \param self The state of the schedule - * \param block_rv The block operates on the target buffer. + * \param block_sref The block operates on the target buffer. * \param buffer_index The index of the buffer in block's read or write region. * \param buffer_index_type The type of the buffer index, kRead or kWrite. * \return The reindex stage block. diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index e9583adbbaa9..9e39c88289cf 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -500,6 +500,92 @@ class CacheLocDetector : public StmtVisitor { int loc_pos_{-1}; }; +/*! \brief Detect the insertion position of the new cache stage */ +class CacheBufferLocDetector : public StmtVisitor { + public: + /*! + * \brief Detect the insertion position of the cache stage, and write the position into the + * CacheStageInfo \param self The state of the schedule \param block_sref The sref of the unique + * block of the buffer being applied cache_buffer \param scope_sref The sref + * of the scope block of the cached block \param info The cache stage info. + */ + static void Detect(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_sref, CacheStageInfo* info) { + CacheBufferLocDetector detector(self, block_sref, scope_sref); + detector(GetRef(scope_sref->stmt)); + info->loc_sref = detector.loc_sref_; + info->loc_pos = detector.loc_pos_; + } + + private: + /*! 
+ * \brief Constructor + * \param self The state of the schedule + * \param block_sref The sref of the unique writer block of the buffer being applied cache_buffer + * \param scope_sref The sref of the scope block of the cached block + */ + CacheBufferLocDetector(const ScheduleState self, const StmtSRef& block_sref, + const StmtSRef& scope_sref) + : self_(self), block_sref_(block_sref), scope_sref_(scope_sref) {} + + void VisitStmt_(const SeqStmtNode* seq_stmt) final { + for (size_t i = 0; i < seq_stmt->size(); ++i) { + if (loc_pos_ != -1) { + break; + } + VisitStmt(seq_stmt->seq[i]); + // `pos` can be assigned only once when we visited `block_sref` + if (visited_block_ && loc_pos_ == -1) { + // The offset of insert position from the block + loc_pos_ = i; + return; + } + } + } + + void VisitStmt_(const BlockNode* block) final { + // Only visit the current scope under buffer writer's parent block + if (block == scope_sref_->stmt) { + // The block vistied is the current parent scope + StmtVisitor::VisitStmt_(block); + // Handling cases when insert outside any loop + if (visited_block_ && !loc_sref_.defined()) { + loc_sref_ = self_->stmt2ref.at(block); + // Handling for input buffer + if (loc_pos_ == -1) { + loc_pos_ = 0; + } + } + } else if (block_sref_->stmt == block) { + visited_block_ = true; + } + } + + void VisitStmt_(const ForNode* loop) final { + StmtVisitor::VisitStmt_(loop); + if (visited_block_ && !loc_sref_.defined()) { + loc_sref_ = self_->stmt2ref.at(loop); + if (loc_pos_ == -1) { + loc_pos_ = 0; + } + } + } + + private: + /*! \brief The schedule class */ + const ScheduleState self_; + /*! \brief The dominate block which write the buffer */ + const StmtSRef& block_sref_; + /*! \brief The parent scope of the dominate block */ + const StmtSRef& scope_sref_; + /*! \brief The flag whether we have visited the target block */ + bool visited_block_{false}; + /*! 
\brief The AST node whose body is where the cache stage should be inserted */ + StmtSRef loc_sref_{nullptr}; + /*! \brief The index to insert the cache_read/cache_write stage */ + int loc_pos_{-1}; +}; + /*! \brief Mutator for CacheRead. */ class CacheReadRewriter : public StmtExprMutator { public: @@ -563,8 +649,17 @@ class CacheReadRewriter : public StmtExprMutator { if (block == scope_sref_->stmt) { // If so, put buffer allocation on the parent scope ObjectPtr n = make_object(*stmt.as()); - n->alloc_buffers.push_back(info_->alloc); - stmt = Block(n); + bool alloc_buffer_exists = false; + for (const Buffer& it : n->alloc_buffers) { + if (it.same_as(info_->alloc)) { + alloc_buffer_exists = true; + } + } + // In cache_buffer case, alloc_buffer may be already exits. + if (!alloc_buffer_exists) { + n->alloc_buffers.push_back(info_->alloc); + stmt = Block(n); + } } else { // Otherwise, update read regions and match_buffers // Only make this change if the block is one of the specified consumers. @@ -670,8 +765,17 @@ class CacheWriteRewriter : public StmtExprMutator { // Put buffer allocation on the parent scope if (block == scope_sref_->stmt) { ObjectPtr n = make_object(*stmt.as()); - n->alloc_buffers.push_back(info_->alloc); - stmt = Block(n); + bool alloc_buffer_exists = false; + for (const Buffer& it : n->alloc_buffers) { + if (it.same_as(info_->alloc)) { + alloc_buffer_exists = true; + } + } + // In cache_buffer case, alloc_buffer may be already exits. 
+ if (!alloc_buffer_exists) { + n->alloc_buffers.push_back(info_->alloc); + stmt = Block(n); + } } else { // Since cache_write changes the block, we need to update the buffer it writes auto writes = ReplaceBuffer(block->writes, info_->write_buffer, info_->read_buffer); @@ -1146,6 +1250,102 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu return result_block_sref; } +Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, + const String& storage_scope) { + /*! + * Do cache read then cache write + */ + + // Check 0. Check the input storage scope. + CheckStorageScope(self, storage_scope); + + // Check 1. Check index, get the target buffer and the parent scope + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + Buffer buffer = + GetNthAccessBuffer(self, GetRef(block), write_buffer_index, BufferIndexType::kRead); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); + + // Check 3. Check required region cover for cache_read + CheckRegionCover(self, scope_sref); + + Array results_block_sref; + Buffer new_buffer = WithScope(buffer, storage_scope); + + // Do cache read + // Cache read step 0. Create CacheStageInfo + CacheStageInfo info; + info.read_buffer = buffer; + // Create the corresponding buffer to be written for cache_read + info.write_buffer = new_buffer; + // Create the corresponding buffer allocation + info.alloc = info.write_buffer; + // Indicate which buffers should consume the cache. + info.consumer_blocks.push_back(block_sref); + + // Cache read step 1. Update cache stage info for cache_read. 
+ BufferRegion cache_region{nullptr}; + Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, buffer); + + StmtSRef write_block_sref = _write_block_sref.value(); + const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref); + // Find the producing region + BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, buffer).value(); + StmtSRef parent_sref = GetRef(write_block_sref->parent); + + // Detect insert position + CacheBufferLocDetector::Detect(self, write_block_sref, scope_sref, &info); + cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref); + + // Cache read step 2. Making new cache stage block and rewrite readers. + Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, + /*storage_scope=*/storage_scope); + Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); + + // Cache read step 3. Replacing and updating flags for cache read. + self->Replace(scope_sref, new_scope, info.block_reuse); + StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); + BlockInfo& block_info_read = self->block_info[result_block_sref]; + block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref); + block_info_read.region_cover = true; + block_info_read.scope->stage_pipeline = true; + results_block_sref.push_back(result_block_sref); + + // Do cache write + // Cache write step 0. Update cache stage info for cache_read. + info.read_buffer = new_buffer; + // Create the corresponding buffer to be written, i.e. result of cache_write + info.write_buffer = buffer; + // Create the corresponding buffer allocation + info.alloc = info.read_buffer; + info.consumer_blocks.clear(); + + // Cache write step 1. 
Find the producing region and insert position + region = GetBufferRegionFromBuffer(block->writes, buffer).value(); + parent_sref = GetRef(block_sref->parent); + // Detect insert position + CacheBufferLocDetector::Detect(self, block_sref, scope_sref, &info); + // insert after target block for cache write + info.loc_pos += 1; + cache_region = RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref); + + // Cache write step 2. Making new cache stage block and rewrite readers. + Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, + /*storage_scope=*/storage_scope); + new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref, + /*writer_block_sref=*/block_sref, /*info=*/&info); + + // Cache write step 4. Replacing and updating flags for cache write. + self->Replace(scope_sref, new_scope, info.block_reuse); + result_block_sref = self->stmt2ref.at(cache_write_stage.get()); + BlockInfo& block_info_write = self->block_info[result_block_sref]; + block_info_write.affine_binding = CalculateAffineFlag(self, result_block_sref); + block_info_write.region_cover = true; + block_info_write.scope->stage_pipeline = true; + results_block_sref.push_back(result_block_sref); + + return results_block_sref; +} + StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); @@ -1282,6 +1482,34 @@ struct CacheWriteTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct CacheBufferTraits : public UnpackedInstTraits { + static constexpr const char* kName = "CacheBuffer"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block, + Integer read_buffer_index, String storage_scope) { + return 
sch->CacheBuffer(block, read_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String block, Integer read_buffer_index, + String storage_scope) { + PythonAPICall py("cache_buffer"); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + struct ReIndexTraits : public UnpackedInstTraits { static constexpr const char* kName = "ReIndex"; static constexpr bool kIsPure = false; @@ -1315,6 +1543,7 @@ struct ReIndexTraits : public UnpackedInstTraits { TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits); +TVM_REGISTER_INST_KIND_TRAITS(CacheBufferTraits); TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 2f27dbb9fbf1..9df8115e3964 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -179,6 +179,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheBuffer") + .set_body_method(&ScheduleNode::CacheBuffer); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type) { diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 9ff793dc39dd..b65f81739307 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -307,6 +307,22 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } +Array TracedScheduleNode::CacheBuffer(const BlockRV& block_rv, int read_buffer_index, + const 
String& storage_scope) { + Array result = + ConcreteScheduleNode::CacheBuffer(block_rv, read_buffer_index, storage_scope); + Array results; + for (const BlockRV& r : result) { + results.push_back(r); + } + static const InstructionKind& kind = InstructionKind::Get("CacheBuffer"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*outputs=*/results)); + return result; +} + BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) { BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 0e83b35f44e9..cfbc6176d9fc 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -76,6 +76,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { const Array consumer_blocks = {}) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) final; + Array CacheBuffer(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) final; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) final; /******** Schedule: Compute location ********/ diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index 334fb988d775..bae1a99cfddf 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -241,6 +241,15 @@ def inplace_func(data_io: T.Buffer[(64), "int32"]): data_io[v0] = data_1d[v0] +@T.prim_func +def inplace_call(data_io: T.Buffer[(64), "int32"]): + for i0 in T.serial(1): + with T.block("ext_call"): + T.reads(data_io[:64]) + T.writes(data_io[:64]) + T.evaluate(T.call_extern("call_impl", data_io.data, dtype="")) + + 
########## Expected function after cache_read ########## @@ -548,6 +557,42 @@ def cache_read_inplace(data_io: T.Buffer[64, "int32"]) -> None: data_io[v0] = data_1d[v0] +@T.prim_func +def cache_buffer_inplace(data_io: T.Buffer[64, "int32"]) -> None: + data_io_local = T.alloc_buffer([64], dtype="int32", scope="local") + data_io_global = T.alloc_buffer([64], dtype="int32") + data_io_global_1 = T.alloc_buffer([64], dtype="int32") + for ax0 in T.serial(64): + with T.block("data_io_global"): + v0 = T.axis.spatial(64, ax0) + T.reads(data_io[v0]) + T.writes(data_io_global[v0]) + data_io_global[v0] = data_io[v0] + for i0 in T.serial(1): + for ax0 in T.serial(64): + with T.block("data_io_local"): + v0 = T.axis.spatial(64, ax0) + T.reads(data_io_global[v0]) + T.writes(data_io_local[v0]) + data_io_local[v0] = data_io_global[v0] + with T.block("ext_call"): + T.reads(data_io_local[0:64]) + T.writes(data_io_local[0:64]) + T.evaluate(T.call_extern("call_impl", data_io_local.data, dtype="")) + for ax0 in T.serial(64): + with T.block("data_io_local"): + v0 = T.axis.spatial(64, ax0) + T.reads(data_io_local[v0]) + T.writes(data_io_global_1[v0]) + data_io_global_1[v0] = data_io_local[v0] + for ax0 in T.serial(64): + with T.block("data_io_global"): + v0 = T.axis.spatial(64, ax0) + T.reads(data_io_global_1[v0]) + T.writes(data_io[v0]) + data_io[v0] = data_io_global_1[v0] + + ########## Expected function after cache_write ########## @@ -931,6 +976,19 @@ def test_inplace_cache_read(): verify_trace_roundtrip(sch=sch, mod=inplace_func) +def test_inplace_cache_buffer(): + # cache buffer could introduce WAR, which is expected but stage pipeline property changes + debug_mask = tvm.tir.schedule.state.ScheduleDebugMask.VERIFY_SREF_TREE + sch = tvm.tir.Schedule(inplace_call, debug_mask=debug_mask) + block = sch.get_block("ext_call") + blocks = sch.cache_buffer(block, 0, "local") + block = sch.cache_read(blocks[0], 0, "global", [blocks[0]]) + block = sch.cache_write(blocks[1], 0, "global") + + 
tvm.ir.assert_structural_equal(cache_buffer_inplace, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=inplace_call, debug_mask=debug_mask) + + ########## Testcases for cache_write ########## From 416a6d6b95788026906371a8b9398efc5fc9ae6e Mon Sep 17 00:00:00 2001 From: Min Chen Date: Sat, 1 Oct 2022 03:24:39 +0000 Subject: [PATCH 2/4] Apply review comments. --- .../schedule/primitive/cache_read_write.cc | 32 ++++++------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 9e39c88289cf..f1c96762eb63 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -66,7 +66,7 @@ struct CacheStageInfo { /*! \brief The buffer to be written. */ Buffer write_buffer; /*! \brief The buffer allocation to be inserted into the block signature. */ - Buffer alloc; + Optional alloc; /*! \brief The AST node whose body is where the cache stage should be inserted. */ StmtSRef loc_sref; /*! \brief The index to insert the cache_read/cache_write stage. */ @@ -649,15 +649,9 @@ class CacheReadRewriter : public StmtExprMutator { if (block == scope_sref_->stmt) { // If so, put buffer allocation on the parent scope ObjectPtr n = make_object(*stmt.as()); - bool alloc_buffer_exists = false; - for (const Buffer& it : n->alloc_buffers) { - if (it.same_as(info_->alloc)) { - alloc_buffer_exists = true; - } - } // In cache_buffer case, alloc_buffer may be already exits. 
- if (!alloc_buffer_exists) { - n->alloc_buffers.push_back(info_->alloc); + if (info_->alloc.defined()) { + n->alloc_buffers.push_back(info_->alloc.value()); stmt = Block(n); } } else { @@ -765,15 +759,9 @@ class CacheWriteRewriter : public StmtExprMutator { // Put buffer allocation on the parent scope if (block == scope_sref_->stmt) { ObjectPtr n = make_object(*stmt.as()); - bool alloc_buffer_exists = false; - for (const Buffer& it : n->alloc_buffers) { - if (it.same_as(info_->alloc)) { - alloc_buffer_exists = true; - } - } // In cache_buffer case, alloc_buffer may be already exits. - if (!alloc_buffer_exists) { - n->alloc_buffers.push_back(info_->alloc); + if (info_->alloc.defined()) { + n->alloc_buffers.push_back(info_->alloc.value()); stmt = Block(n); } } else { @@ -1002,7 +990,7 @@ class ReIndexRewriter : public StmtExprMutator { explicit ReIndexRewriter(const StmtSRef& block_sref, CacheStageInfo* info, const std::unordered_set& covered) : block_sref_(block_sref), info_(info), covered_(covered) { - new_buffer_ = info->alloc; + new_buffer_ = info->alloc.value(); old_buffer_ = info->read_buffer.same_as(new_buffer_) ? 
info->write_buffer : info->read_buffer; } @@ -1014,7 +1002,7 @@ class ReIndexRewriter : public StmtExprMutator { // Insert cache stage into the loop ObjectPtr n = make_object(*stmt.as()); n->body = InsertCacheStage(n->body, info_->loc_pos, info_->cache_stage); - n->alloc_buffers.push_back(info_->alloc); + n->alloc_buffers.push_back(info_->alloc.value()); stmt = Block(n); info_->block_reuse.Set(old_stmt, stmt); return std::move(stmt); @@ -1306,8 +1294,6 @@ Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get()); BlockInfo& block_info_read = self->block_info[result_block_sref]; block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref); - block_info_read.region_cover = true; - block_info_read.scope->stage_pipeline = true; results_block_sref.push_back(result_block_sref); // Do cache write @@ -1316,7 +1302,7 @@ Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int // Create the corresponding buffer to be written, i.e. result of cache_write info.write_buffer = buffer; // Create the corresponding buffer allocation - info.alloc = info.read_buffer; + info.alloc = nullptr; info.consumer_blocks.clear(); // Cache write step 1. Find the producing region and insert position @@ -1340,7 +1326,7 @@ Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int BlockInfo& block_info_write = self->block_info[result_block_sref]; block_info_write.affine_binding = CalculateAffineFlag(self, result_block_sref); block_info_write.region_cover = true; - block_info_write.scope->stage_pipeline = true; + block_info_write.scope->stage_pipeline = false; results_block_sref.push_back(result_block_sref); return results_block_sref; From c5e1f144a4d8757955790b444a2c3e7fdd06bc0e Mon Sep 17 00:00:00 2001 From: Min Chen Date: Sun, 9 Oct 2022 03:52:18 +0000 Subject: [PATCH 3/4] Fix API description and add more checks. 
--- include/tvm/tir/schedule/schedule.h | 5 +- python/tvm/tir/schedule/schedule.py | 3 +- src/tir/schedule/primitive.h | 5 +- .../schedule/primitive/cache_read_write.cc | 57 ++++++++++++------- 4 files changed, 43 insertions(+), 27 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 51e211100a29..2a4078f8c4be 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -405,10 +405,11 @@ class ScheduleNode : public runtime::Object { const String& storage_scope) = 0; /*! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. - * \param block_rv The block operates on the target buffer. + * It requires the the target block both read & write the target buffer. + * \param block_rv The target block operates on the target buffer. * \param read_buffer_index The index of the buffer in block's read region. * \param storage_scope The target storage scope - * \return The reindex stage block. + * \return The cache stage blocks, cache read block together with cache write block. */ virtual Array CacheBuffer(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope) = 0; diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 22a452a9f08d..94112e6473e6 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1197,11 +1197,12 @@ def cache_buffer( storage_scope: str, ) -> List[BlockRV]: """Create blocks that reads & write a buffer region into a cache block. + It requires the the target block both read & write the target buffer. Parameters ---------- block : Union[BlockRV, str] - The producer block of the target buffer. + The target block operates on the target buffer. 
read_buffer_index: int The index of the buffer in block's read region, the unique diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 365b2175008d..582f29b06de0 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -270,11 +270,12 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int /*! *! * \brief Create 2 blocks that read&write a buffer region into a read/write cache. + * It requires the the target block both read & write the target buffer. * \param self The state of the schedule - * \param block_sref The block operates on the target buffer. + * \param block_sref The target block operates on the target buffer. * \param read_buffer_index The index of the buffer in block's read region. * \param storage_scope The target storage scope - * \return The reindex stage block. + * \return The cache stage blocks, cache read block together with cache write block. */ TVM_DLL Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, const String& storage_scope); diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index f1c96762eb63..27e38f1b2550 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -1238,7 +1238,27 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu return result_block_sref; } -Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, +/*! \brief The schedule error that the target block doesn't both read&write target buffer. 
*/ +class NotReadWriteError : public ScheduleError { + public: + NotReadWriteError(IRModule mod, Block block, Buffer buffer) + : mod_(std::move(mod)), block_(std::move(block)), buffer_(std::move(buffer)) {} + String FastErrorString() const final { + return "ScheduleError: The target block does not both read & write target buffer."; + } + + String DetailRenderTemplate() const final { + return "The target block {0} does not both read & write target buffer {1}."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_, buffer_}; } + IRModule mod_; + Block block_; + Buffer buffer_; +}; + +Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, const String& storage_scope) { /*! * Do cache read then cache write @@ -1250,12 +1270,20 @@ Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int // Check 1. Check index, get the target buffer and the parent scope const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Buffer buffer = - GetNthAccessBuffer(self, GetRef(block), write_buffer_index, BufferIndexType::kRead); + GetNthAccessBuffer(self, GetRef(block), read_buffer_index, BufferIndexType::kRead); StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); // Check 3. Check required region cover for cache_read CheckRegionCover(self, scope_sref); + // Check 4. Check if target block both read & write target buffer. 
+ const BlockNode* rw_block = TVM_SREF_TO_BLOCK(block_sref); + Optional read_region = GetBufferRegionFromBuffer(rw_block->reads, buffer); + Optional write_region = GetBufferRegionFromBuffer(rw_block->writes, buffer); + if (!read_region.defined() || !write_region.defined()) { + throw NotReadWriteError(self->mod, GetRef(rw_block), buffer); + } + Array results_block_sref; Buffer new_buffer = WithScope(buffer, storage_scope); @@ -1270,22 +1298,11 @@ Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int // Indicate which buffers should consume the cache. info.consumer_blocks.push_back(block_sref); - // Cache read step 1. Update cache stage info for cache_read. - BufferRegion cache_region{nullptr}; - Optional _write_block_sref = GetOnlyWriteBlock(self, scope_sref, buffer); - - StmtSRef write_block_sref = _write_block_sref.value(); - const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref); - // Find the producing region - BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, buffer).value(); - StmtSRef parent_sref = GetRef(write_block_sref->parent); - - // Detect insert position - CacheBufferLocDetector::Detect(self, write_block_sref, scope_sref, &info); - cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref); + // Cache read step 1. Detect insert position + CacheBufferLocDetector::Detect(self, block_sref, scope_sref, &info); // Cache read step 2. Making new cache stage block and rewrite readers. - Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, + Block cache_read_stage = MakeCacheStage(/*cache_region=*/read_region.value(), /*info=*/&info, /*storage_scope=*/storage_scope); Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); @@ -1305,17 +1322,13 @@ Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int info.alloc = nullptr; info.consumer_blocks.clear(); - // Cache write step 1. 
Find the producing region and insert position - region = GetBufferRegionFromBuffer(block->writes, buffer).value(); - parent_sref = GetRef(block_sref->parent); - // Detect insert position + // Cache write step 1. Detect insert position CacheBufferLocDetector::Detect(self, block_sref, scope_sref, &info); // insert after target block for cache write info.loc_pos += 1; - cache_region = RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref); // Cache write step 2. Making new cache stage block and rewrite readers. - Block cache_write_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info, + Block cache_write_stage = MakeCacheStage(/*cache_region=*/write_region.value(), /*info=*/&info, /*storage_scope=*/storage_scope); new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref, /*writer_block_sref=*/block_sref, /*info=*/&info); From 633667a6c3869eea21f59d03e8a0fa0ef13d7cb1 Mon Sep 17 00:00:00 2001 From: Min Chen Date: Sun, 9 Oct 2022 04:52:25 +0000 Subject: [PATCH 4/4] Rename cache_buffer to cache_inplace. --- include/tvm/tir/schedule/schedule.h | 4 +-- python/tvm/tir/schedule/schedule.py | 17 +++++----- src/tir/schedule/concrete_schedule.cc | 6 ++-- src/tir/schedule/concrete_schedule.h | 4 +-- src/tir/schedule/primitive.h | 4 +-- .../schedule/primitive/cache_read_write.cc | 34 +++++++++---------- src/tir/schedule/schedule.cc | 4 +-- src/tir/schedule/traced_schedule.cc | 8 ++--- src/tir/schedule/traced_schedule.h | 4 +-- .../test_tir_schedule_cache_read_write.py | 10 +++--- 10 files changed, 48 insertions(+), 47 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 2a4078f8c4be..9ec2841ebd5e 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -411,8 +411,8 @@ class ScheduleNode : public runtime::Object { * \param storage_scope The target storage scope * \return The cache stage blocks, cache read block together with cache write block. 
*/ - virtual Array CacheBuffer(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) = 0; + virtual Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) = 0; /*! * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. * The layout of the cache will be the same as by the iterators of the block that reads/writes the diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 94112e6473e6..2884e0066fab 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1190,7 +1190,7 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: ) @type_checked - def cache_buffer( + def cache_inplace( self, block: Union[BlockRV, str], read_buffer_index: Union[int, str, Buffer], @@ -1198,6 +1198,7 @@ def cache_buffer( ) -> List[BlockRV]: """Create blocks that reads & write a buffer region into a cache block. It requires the the target block both read & write the target buffer. + Mainly for inplace operation. Parameters ---------- @@ -1220,28 +1221,28 @@ def cache_buffer( Examples -------- - Before cache_buffer, in TensorIR, the IR is: + Before cache_inplace, in TensorIR, the IR is: .. code-block:: python @T.prim_func - def before_cache_buffer(data_io: T.Buffer[(64), "int32"]): + def before_cache_inplace(data_io: T.Buffer[(64), "int32"]): for i0 in T.serial(1): with T.block("A"): T.reads(data_io[:64]) T.writes(data_io[:64]) T.evaluate(T.call_extern("call_impl", data_io.data, dtype="")) - Create the schedule and cache_buffer: + Create the schedule and cache_inplace: .. 
code-block:: python - sch = tir.Schedule(before_cache_buffer) + sch = tir.Schedule(before_cache_inplace) block_a = sch.get_block("A") - sch.cache_buffer(block_a, 0, "local") + sch.cache_inplace(block_a, 0, "local") print(sch.mod["main"].script()) - After applying cache_buffer, the IR becomes: + After applying cache_inplace, the IR becomes: .. code-block:: python @@ -1273,7 +1274,7 @@ def cache_inplace(data_io: T.Buffer[64, "int32"]) -> None: _, read_buffer_index, _ = self._normalize_buffer_arg( block, read_buffer_index, required_buffer_type="read" ) - return _ffi_api.ScheduleCacheBuffer( # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleCacheInplace( # type: ignore # pylint: disable=no-member self, block, read_buffer_index, storage_scope ) diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 346c597b38a0..2e1a7ce77d8c 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -561,11 +561,11 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff return CreateRV(result); } -Array ConcreteScheduleNode::CacheBuffer(const BlockRV& block_rv, int write_buffer_index, - const String& storage_scope) { +Array ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) { Array results; TVM_TIR_SCHEDULE_BEGIN(); - results = tir::CacheBuffer(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope); + results = tir::CacheInplace(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope); TVM_TIR_SCHEDULE_END("cache-buffer", this->error_render_level_); this->state_->DebugVerify(); Array return_blocks; diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index c1ddf1131bc3..bfdc082d4ce6 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -116,8 +116,8 @@ class ConcreteScheduleNode : public 
ScheduleNode { const Array consumer_blocks = {}) override; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) override; - Array CacheBuffer(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) override; + Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) override; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) override; /******** Schedule: Compute location ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 582f29b06de0..88331fb5b9d3 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -277,8 +277,8 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int * \param storage_scope The target storage scope * \return The cache stage blocks, cache read block together with cache write block. */ -TVM_DLL Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, - int read_buffer_index, const String& storage_scope); +TVM_DLL Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope); /*! *! * \brief Create a block that read/write a buffer region into a read/write cache with reindexing. diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 27e38f1b2550..cedcd2708bf8 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -501,17 +501,17 @@ class CacheLocDetector : public StmtVisitor { }; /*! \brief Detect the insertion position of the new cache stage */ -class CacheBufferLocDetector : public StmtVisitor { +class CacheInplaceLocDetector : public StmtVisitor { public: /*! 
 * \brief Detect the insertion position of the cache stage, and write the position into the * CacheStageInfo \param self The state of the schedule \param block_sref The sref of the unique - * block of the buffer being applied cache_buffer \param scope_sref The sref + * block of the buffer being applied cache_inplace \param scope_sref The sref * of the scope block of the cached block \param info The cache stage info. */ static void Detect(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_sref, CacheStageInfo* info) { - CacheBufferLocDetector detector(self, block_sref, scope_sref); + CacheInplaceLocDetector detector(self, block_sref, scope_sref); detector(GetRef(scope_sref->stmt)); info->loc_sref = detector.loc_sref_; info->loc_pos = detector.loc_pos_; @@ -521,11 +521,11 @@ class CacheBufferLocDetector : public StmtVisitor { /*! * \brief Constructor * \param self The state of the schedule - * \param block_sref The sref of the unique writer block of the buffer being applied cache_buffer + * \param block_sref The sref of the unique writer block of the buffer being applied cache_inplace * \param scope_sref The sref of the scope block of the cached block */ - CacheBufferLocDetector(const ScheduleState self, const StmtSRef& block_sref, - const StmtSRef& scope_sref) + CacheInplaceLocDetector(const ScheduleState self, const StmtSRef& block_sref, + const StmtSRef& scope_sref) : self_(self), block_sref_(block_sref), scope_sref_(scope_sref) {} void VisitStmt_(const SeqStmtNode* seq_stmt) final { @@ -649,7 +649,7 @@ class CacheReadRewriter : public StmtExprMutator { if (block == scope_sref_->stmt) { // If so, put buffer allocation on the parent scope ObjectPtr n = make_object(*stmt.as()); - // In cache_buffer case, alloc_buffer may be already exits. + // In cache_inplace case, alloc_buffer may already exist. 
if (info_->alloc.defined()) { n->alloc_buffers.push_back(info_->alloc.value()); stmt = Block(n); @@ -759,7 +759,7 @@ class CacheWriteRewriter : public StmtExprMutator { // Put buffer allocation on the parent scope if (block == scope_sref_->stmt) { ObjectPtr n = make_object(*stmt.as()); - // In cache_buffer case, alloc_buffer may be already exits. + // In cache_inplace case, alloc_buffer may already exist. if (info_->alloc.defined()) { n->alloc_buffers.push_back(info_->alloc.value()); stmt = Block(n); @@ -1258,8 +1258,8 @@ class NotReadWriteError : public ScheduleError { Buffer buffer_; }; -Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, - const String& storage_scope) { +Array CacheInplace(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index, + const String& storage_scope) { /*! * Do cache read then cache write */ @@ -1299,7 +1299,7 @@ Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int info.consumer_blocks.push_back(block_sref); // Cache read step 1. Detect insert position - CacheBufferLocDetector::Detect(self, block_sref, scope_sref, &info); + CacheInplaceLocDetector::Detect(self, block_sref, scope_sref, &info); // Cache read step 2. Making new cache stage block and rewrite readers. Block cache_read_stage = MakeCacheStage(/*cache_region=*/read_region.value(), /*info=*/&info, /*storage_scope=*/storage_scope); @@ -1323,7 +1323,7 @@ Array CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int info.consumer_blocks.clear(); // Cache write step 1. 
Detect insert position - CacheBufferLocDetector::Detect(self, block_sref, scope_sref, &info); + CacheInplaceLocDetector::Detect(self, block_sref, scope_sref, &info); // insert after target block for cache write info.loc_pos += 1; @@ -1481,8 +1481,8 @@ struct CacheWriteTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; -struct CacheBufferTraits : public UnpackedInstTraits { - static constexpr const char* kName = "CacheBuffer"; +struct CacheInplaceTraits : public UnpackedInstTraits { + static constexpr const char* kName = "CacheInplace"; static constexpr bool kIsPure = false; private: @@ -1492,12 +1492,12 @@ struct CacheBufferTraits : public UnpackedInstTraits { static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer read_buffer_index, String storage_scope) { - return sch->CacheBuffer(block, read_buffer_index->value, storage_scope); + return sch->CacheInplace(block, read_buffer_index->value, storage_scope); } static String UnpackedAsPython(Array outputs, String block, Integer read_buffer_index, String storage_scope) { - PythonAPICall py("cache_buffer"); + PythonAPICall py("cache_inplace"); py.Input("block", block); py.Input("read_buffer_index", read_buffer_index->value); py.Input("storage_scope", storage_scope); @@ -1542,7 +1542,7 @@ struct ReIndexTraits : public UnpackedInstTraits { TVM_REGISTER_INST_KIND_TRAITS(CacheReadTraits); TVM_REGISTER_INST_KIND_TRAITS(CacheWriteTraits); -TVM_REGISTER_INST_KIND_TRAITS(CacheBufferTraits); +TVM_REGISTER_INST_KIND_TRAITS(CacheInplaceTraits); TVM_REGISTER_INST_KIND_TRAITS(ReIndexTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 9df8115e3964..280d0af92a8c 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -179,8 +179,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") 
.set_body_method(&ScheduleNode::CacheWrite); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheBuffer") - .set_body_method(&ScheduleNode::CacheBuffer); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace") + .set_body_method(&ScheduleNode::CacheInplace); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, int buffer_index_type) { diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index b65f81739307..b67b008feda4 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -307,15 +307,15 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } -Array TracedScheduleNode::CacheBuffer(const BlockRV& block_rv, int read_buffer_index, - const String& storage_scope) { +Array TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) { Array result = - ConcreteScheduleNode::CacheBuffer(block_rv, read_buffer_index, storage_scope); + ConcreteScheduleNode::CacheInplace(block_rv, read_buffer_index, storage_scope); Array results; for (const BlockRV& r : result) { results.push_back(r); } - static const InstructionKind& kind = InstructionKind::Get("CacheBuffer"); + static const InstructionKind& kind = InstructionKind::Get("CacheInplace"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv}, /*attrs=*/{Integer(read_buffer_index), storage_scope}, diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index cfbc6176d9fc..016de60726b9 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -76,8 +76,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { const Array consumer_blocks = {}) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) final; - Array CacheBuffer(const BlockRV& block_rv, int 
read_buffer_index, - const String& storage_scope) final; + Array CacheInplace(const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) final; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) final; /******** Schedule: Compute location ********/ diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index bae1a99cfddf..a237a5b75839 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -558,7 +558,7 @@ def cache_read_inplace(data_io: T.Buffer[64, "int32"]) -> None: @T.prim_func -def cache_buffer_inplace(data_io: T.Buffer[64, "int32"]) -> None: +def cache_inplace_buffer(data_io: T.Buffer[64, "int32"]) -> None: data_io_local = T.alloc_buffer([64], dtype="int32", scope="local") data_io_global = T.alloc_buffer([64], dtype="int32") data_io_global_1 = T.alloc_buffer([64], dtype="int32") @@ -976,16 +976,16 @@ def test_inplace_cache_read(): verify_trace_roundtrip(sch=sch, mod=inplace_func) -def test_inplace_cache_buffer(): - # cache buffer could introduce WAR, which is expected but stage pipeline property changes +def test_cache_inplace(): + # cache_inplace could introduce WAR, which is expected but stage pipeline property changes debug_mask = tvm.tir.schedule.state.ScheduleDebugMask.VERIFY_SREF_TREE sch = tvm.tir.Schedule(inplace_call, debug_mask=debug_mask) block = sch.get_block("ext_call") - blocks = sch.cache_buffer(block, 0, "local") + blocks = sch.cache_inplace(block, 0, "local") block = sch.cache_read(blocks[0], 0, "global", [blocks[0]]) block = sch.cache_write(blocks[1], 0, "global") - tvm.ir.assert_structural_equal(cache_buffer_inplace, sch.mod["main"]) + tvm.ir.assert_structural_equal(cache_inplace_buffer, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=inplace_call, debug_mask=debug_mask)