From 1c8c2454746297cc850007cf29ca0534219e7466 Mon Sep 17 00:00:00 2001
From: Min Chen
Date: Thu, 20 Oct 2022 10:07:21 +0000
Subject: [PATCH] [TIR][Schedule] Add cache_index to precompute index of buffer load

---
 include/tvm/arith/int_set.h                   |   9 +
 include/tvm/tir/schedule/schedule.h           |   8 +
 python/tvm/tir/schedule/schedule.py           |  86 ++++
 src/tir/schedule/concrete_schedule.cc         |  13 +
 src/tir/schedule/concrete_schedule.h          |   1 +
 src/tir/schedule/primitive.h                  |   9 +
 src/tir/schedule/primitive/cache_index.cc     | 484 ++++++++++++++++++
 src/tir/schedule/schedule.cc                  |   2 +
 src/tir/schedule/traced_schedule.cc           |  14 +
 src/tir/schedule/traced_schedule.h            |   1 +
 .../unittest/test_tir_schedule_cache_index.py |  78 +++
 11 files changed, 705 insertions(+)
 create mode 100644 src/tir/schedule/primitive/cache_index.cc
 create mode 100644 tests/python/unittest/test_tir_schedule_cache_index.py

diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h
index 5ef7108d9797..60d7c53d28e8 100644
--- a/include/tvm/arith/int_set.h
+++ b/include/tvm/arith/int_set.h
@@ -169,6 +169,15 @@ Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>&
  * \return An integer set that can cover all the possible values of e.
  */
 IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
+/*!
+ * \brief Find a symbolic integer set that contains all possible values of
+ *  e given the domain of each variable.
+ *
+ * \param e The expression to be evaluated.
+ * \param dom_map The domain of each variable.
+ * \return An integer set that can cover all the possible values of e.
+ */
+IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map);
 /*!
  * \brief Same as EvalSet, but takes unordered_map
  *
diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 9ec2841ebd5e..3394e37070ff 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -413,6 +413,14 @@ class ScheduleNode : public runtime::Object {
    */
   virtual Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
                                       const String& storage_scope) = 0;
+  /*!
+   * \brief Create a block to cache the precomputed index for later use.
+   * If there is no index computation, keep the block unchanged.
+   * \param block_rv The target block
+   * \param buffer_index The index of the target buffer in the block's read region
+   * \return The cache stage blocks.
+   */
+  virtual Array<BlockRV> CacheIndex(const BlockRV& block_rv, int buffer_index) = 0;
   /*!
    * \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
    * The layout of the cache will be the same as by the iterators of the block that reads/writes the
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 4814271f4023..6c620045e90d 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1278,6 +1278,92 @@ def cache_inplace(data_io: T.Buffer[64, "int32"]) -> None:
             self, block, read_buffer_index, storage_scope
         )
 
+    @type_checked
+    def cache_index(
+        self, block: Union[BlockRV, str], buffer_index: Union[int, str, Buffer]
+    ) -> List[BlockRV]:
+        """Create a block to cache the precomputed index for later use.
+        If there is no index computation, keep the block unchanged.
+
+        Parameters
+        ----------
+        block : Union[BlockRV, str]
+            The target block that operates on the target buffer.
+
+        buffer_index : int
+            The index of the target buffer in the block's read region.
+
+        Returns
+        -------
+        cached_blocks : List[BlockRV]
+            The blocks of the stage writing the cache buffers.
+
+        Examples
+        --------
+        Before cache_index, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def resize(a: T.handle, b: T.handle) -> None:
+                A = T.match_buffer(a, (1, 3, 40, 40))
+                B = T.match_buffer(b, (1, 3, 80, 80))
+                for i0, i1, i2, i3 in T.grid(1, 3, 80, 80):
+                    with T.block("A"):
+                        n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                        B[n, c, vi, vj] = A[n, c, vi // 4 + vj // 4, vj // 2]
+
+        Create the schedule and cache_index:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(resize)
+            block_a = sch.get_block("A")
+            sch.cache_index(block_a, 0)
+            print(sch.mod["main"].script())
+
+        After applying cache_index, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def resize_cache_index(
+                A: T.Buffer[(1, 3, 40, 40), "float32"], B: T.Buffer[(1, 3, 80, 80), "float32"]
+            ) -> None:
+                index_var_0 = T.alloc_buffer([80, 80], dtype="int32", strides=[1])
+                index_var_1 = T.alloc_buffer([80], dtype="int32", strides=[1])
+                for ax0, ax1 in T.grid(80, 80):
+                    with T.block("index_0"):
+                        v0 = T.axis.spatial(80, ax0)
+                        v1 = T.axis.spatial(80, ax1)
+                        T.reads()
+                        T.writes(index_var_0[v0, v1])
+                        index_var_0[v0, v1] = v0 // 4 + v1 // 4
+                for ax0 in T.serial(80):
+                    with T.block("index_1"):
+                        v0 = T.axis.spatial(80, ax0)
+                        T.reads()
+                        T.writes(index_var_1[v0])
+                        index_var_1[v0] = v0 // 2
+                for i0, i1, i2, i3 in T.grid(1, 3, 80, 80):
+                    with T.block("A"):
+                        n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                        T.reads(A[n, c, vi // 4 + vj // 4, vj // 2])
+                        T.writes(B[n, c, vi, vj])
+                        B[n, c, vi, vj] = A[n, c, index_var_0[vi, vj], index_var_1[vj]]
+
+        """
+        block = self._normalize_block_arg(block)
+
+        if not isinstance(buffer_index, int):
+            _, buffer_index, _ = self._normalize_buffer_arg(
+                block, buffer_index, required_buffer_type="read"
+            )
+        return _ffi_api.ScheduleCacheIndex(  # type: ignore # pylint: disable=no-member
+            self, block, buffer_index
+        )
+
     @type_checked
     def reindex(
         self,
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 3960087cf745..7144ba8ae1f5 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -574,6 +574,19 @@ Array<BlockRV> ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int w
   return return_blocks;
 }
 
+Array<BlockRV> ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv, int buffer_index) {
+  Array<StmtSRef> result;
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::CacheIndex(state_, this->GetSRef(block_rv), buffer_index);
+  TVM_TIR_SCHEDULE_END("cache-index", this->error_render_level_);
+  this->state_->DebugVerify();
+  Array<BlockRV> return_blocks;
+  for (const StmtSRef& block_sref : result) {
+    return_blocks.push_back(CreateRV<BlockRV>(block_sref));
+  }
+  return return_blocks;
+}
+
 BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
                                       BufferIndexType buffer_index_type) {
   StmtSRef result{nullptr};
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index bfdc082d4ce6..384b1ce2425f 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -118,6 +118,7 @@ class ConcreteScheduleNode : public ScheduleNode {
                            const String& storage_scope) override;
   Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
                               const String& storage_scope) override;
+  Array<BlockRV> CacheIndex(const BlockRV& block_rv, int buffer_index) override;
   BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
                   BufferIndexType buffer_index_type) override;
   /******** Schedule: Compute location ********/
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 88331fb5b9d3..8e5ab91b8e7c 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -279,6 +279,15 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int
  */
 TVM_DLL Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref,
                                      int read_buffer_index, const String& storage_scope);
+/*!
+ * \brief Create a block to cache the precomputed index for later use.
+ * If there is no index computation, keep the block unchanged.
+ * \param block_sref The target block
+ * \param buffer_index The index of the target buffer in the block's read region
+ * \return The cache stage blocks.
+ */
+TVM_DLL Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,
+                                   int buffer_index);
 /*!
  *!
  * \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
diff --git a/src/tir/schedule/primitive/cache_index.cc b/src/tir/schedule/primitive/cache_index.cc
new file mode 100644
index 000000000000..ba58f81038cb
--- /dev/null
+++ b/src/tir/schedule/primitive/cache_index.cc
@@ -0,0 +1,484 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/arith/int_set.h>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Helper Functions/Classes ********/
+
+/*! \brief The auxiliary info used for the insertion point and content of the cache stage. */
+struct IndexInfo {
+  /*! \brief The target buffer to cache the index. */
+  Buffer target_buffer;
+  /*! \brief The cache buffers to store the precomputed indices. */
+  std::vector<Buffer> cache_buffer;
+  /*! \brief The exprs to be precomputed. */
+  std::vector<PrimExpr> index_exprs;
+  /*! \brief The ranges of the loop vars relating to the index computation. */
+  Map<Var, Range> range_map;
+  /*! \brief The binding table of the block vars and the loop vars. */
+  Map<Var, PrimExpr> var_binding;
+  /*! \brief The block vars of the target block. */
+  std::vector<Array<Var>> origin_block_vars;
+  /*! \brief The position to insert the cache stage. */
+  size_t loc_pos;
+  /*! \brief The cache stage to be inserted. */
+  Stmt cache_stage;
+  /*! \brief The map used for ScheduleStateNode::Replace. */
+  Map<Block, Block> block_reuse;
+};
+
+/*!
+ * \brief Determine the data type based on the integer range.
+ * \param range The range of the integer.
+ * \returns A data type that covers the input range.
+ */
+DataType DetermineDatatype(const arith::IntSet& range) {
+  arith::Analyzer ana;
+  if (ana.CanProve(range.min() >= INT32_MIN && range.max() <= INT32_MAX)) {
+    return DataType::Int(32);
+  } else {
+    ICHECK(ana.CanProve(range.min() >= make_const(DataType::Int(64), INT64_MIN) &&
+                        range.max() <= make_const(DataType::Int(64), INT64_MAX)));
+    return DataType::Int(64);
+  }
+}
+
+/*! \brief Collect the index info to be cached */
+class IndexInfoCollector : public StmtExprVisitor {
+ public:
+  /*!
+   * \brief Collect the index info for cache_index and write it into the IndexInfo
+   * \param self The state of the schedule
+   * \param block_sref The sref of the target block of the target buffer being applied cache_index
+   * \param scope_sref The sref of the scope block of the target block
+   * \param info The index info.
+   */
+  static void Collect(const ScheduleState& self, const StmtSRef& block_sref,
+                      const StmtSRef& scope_sref, IndexInfo* info) {
+    IndexInfoCollector collector(self, block_sref, scope_sref, info->target_buffer);
+    collector(GetRef<Stmt>(scope_sref->stmt));
+    info->loc_pos = collector.loc_pos_;
+    info->index_exprs = collector.exprs_;
+    info->range_map = collector.range_map_;
+  }
+
+ private:
+  /*!
+   * \brief Constructor
+   * \param self The state of the schedule
+   * \param block_sref The sref of the target block of the buffer being applied cache_index
+   * \param scope_sref The sref of the scope block of the target block
+   * \param buffer The target buffer to cache the indices
+   */
+  IndexInfoCollector(const ScheduleState self, const StmtSRef& block_sref,
+                     const StmtSRef& scope_sref, const Buffer& buffer)
+      : self_(self), block_sref_(block_sref), scope_sref_(scope_sref), buffer_(buffer) {}
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    for (size_t i = 0; i < seq_stmt->size(); ++i) {
+      if (loc_pos_ != -1) {
+        break;
+      }
+      VisitStmt(seq_stmt->seq[i]);
+      // `loc_pos_` can be assigned only once, after we have visited `block_sref`
+      if (visited_block_ && loc_pos_ == -1 && update_seq_pos_) {
+        // The offset of the insert position from the block
+        loc_pos_ = i;
+        return;
+      }
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    // Only visit the target block's parent scope
+    StmtVisitor::VisitStmt_(block);
+    if (block == scope_sref_->stmt) {
+      // The visited block is the current parent scope;
+      // handle the case when there is no SeqStmt in the scope
+      if (visited_block_ && loc_pos_ == -1) {
+        loc_pos_ = 0;
+      }
+    } else if (block_sref_->stmt == block) {
+      visited_block_ = true;
+    }
+    // Update the seq position only at the top scope
+    if (visited_block_ && self_->stmt2ref.at(block)->parent == scope_sref_.get()) {
+      update_seq_pos_ = true;
+    }
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+    StmtVisitor::VisitStmt_(loop);
+    // Update the seq position only at the top scope
+    if (visited_block_ && self_->stmt2ref.at(loop)->parent == scope_sref_.get()) {
+      update_seq_pos_ = true;
+    }
+  }
+
+  void VisitExpr_(const BufferLoadNode* load) final {
+    if (load->buffer.same_as(buffer_)) {
+      for (const PrimExpr& it : load->indices) {
+        if (!it->IsInstance<VarNode>()) {
+          exprs_.push_back(it);
+        }
+      }
+    }
+    ExprVisitor::VisitExpr_(load);
+  }
+
+  /*! \brief The schedule state */
+  const ScheduleState self_;
+  /*! \brief The target block that reads the target buffer */
+  const StmtSRef& block_sref_;
+  /*! \brief The parent scope of the target block */
+  const StmtSRef& scope_sref_;
+  /*! \brief The target buffer to cache the index */
+  const Buffer& buffer_;
+  /*! \brief The calculation exprs to be precomputed */
+  std::vector<PrimExpr> exprs_;
+  /*! \brief Whether we have visited the target block */
+  bool visited_block_{false};
+  /*! \brief The position to insert the cache_index stage */
+  int loc_pos_{-1};
+  /*! \brief Whether we are in the right scope to update the seq position */
+  bool update_seq_pos_{false};
+  /*! \brief The ranges of the iter vars */
+  Map<Var, Range> range_map_;
+};
+
+/*!
+ * \brief Create the loop nests that write the precomputed indices into the index buffers.
+ * \param info The cache stage information, which will be updated in the function.
+ * \returns The blocks of the index computing stages.
+ */
+Array<Block> MakeIndexCacheStage(IndexInfo* info) {
+  Array<Block> blocks;
+  Array<Stmt> bodies;
+  bodies.reserve(info->index_exprs.size());
+  info->cache_buffer.reserve(info->index_exprs.size());
+  const String& storage_scope = info->target_buffer.scope();
+
+  // For each index calculation, create a block to precompute it.
+  for (size_t expr_index = 0; expr_index < info->index_exprs.size(); expr_index++) {
+    const PrimExpr& index_expr = info->index_exprs[expr_index];
+
+    // Collect the block vars that appear in the original index computation
+    info->origin_block_vars.push_back({});
+    PostOrderVisit(index_expr, [&info, &expr_index](const ObjectRef& node) {
+      if (node->IsInstance<VarNode>()) {
+        Var iter_var = Downcast<Var>(node);
+        const Array<Var>& origin_block_var = info->origin_block_vars[expr_index];
+        auto find_result = std::find_if(origin_block_var.begin(), origin_block_var.end(),
+                                        [&](const Var& v) { return v.get() == iter_var.get(); });
+        if (find_result == origin_block_var.end()) {
+          info->origin_block_vars[expr_index].push_back(iter_var);
+        }
+      }
+    });
+
+    // Collect the loop vars corresponding to the collected block vars,
+    // which will be used to create the new loop vars
+    std::vector<Var> iter_vars;
+    for (const Var& it : info->origin_block_vars[expr_index]) {
+      PostOrderVisit(info->var_binding.at(it), [&iter_vars](const ObjectRef& node) {
+        if (node->IsInstance<VarNode>()) {
+          Var iter_var = Downcast<Var>(node);
+          if (std::find_if(iter_vars.begin(), iter_vars.end(), [&](const Var& v) {
+                return v.get() == iter_var.get();
+              }) == iter_vars.end()) {
+            iter_vars.push_back(iter_var);
+          }
+        }
+      });
+    }
+
+    // Infer the value range and create the cache buffer
+    arith::IntSet val_range =
+        arith::EvalSet(Substitute(index_expr, info->var_binding), arith::AsIntSet(info->range_map));
+    DataType data_type = DetermineDatatype(val_range);
+    Var index_buffer_var("index_var_" + std::to_string(expr_index),
+                         PointerType(PrimType(data_type), storage_scope));
+    Array<PrimExpr> buffer_shape;
+    for (const Var& it : info->origin_block_vars[expr_index]) {
+      buffer_shape.push_back(
+          arith::EvalSet(info->var_binding.at(it), arith::AsIntSet(info->range_map)).max() + 1);
+    }
+    info->cache_buffer.push_back(Buffer(index_buffer_var, data_type, buffer_shape, {1}, {0},
+                                        index_buffer_var->name_hint, 0, 0, kDefault));
+
+    // Create the loop vars and the block vars' binding values
+    std::vector<Var> loop_vars;
+    Map<Var, PrimExpr> replace_table;
+    for (const Var& it : iter_vars) {
+      DataType data_type = DetermineDatatype(arith::IntSet::FromRange(info->range_map.at(it)));
+      Var loop_var("ax" + std::to_string(replace_table.size()), data_type);
+      loop_vars.push_back(loop_var);
+      replace_table.Set(it, loop_var);
+    }
+    // Create iter_values from the original block.
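+    // Substituting with replace_table rewrites each block var's binding in terms of the
+    // newly created loop vars, so the cache stage iterates over the same index space as
+    // the original block.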
+    std::vector<PrimExpr> iter_values;
+    for (const Var& it : info->origin_block_vars[expr_index]) {
+      iter_values.push_back(Substitute(info->var_binding.at(it), replace_table));
+    }
+    // The block vars
+    Array<IterVar> block_vars;
+    // The block's access region for the write buffer
+    Region access_region;
+    // The indices used in the block body
+    Array<PrimExpr> access_indices;
+    Map<Var, PrimExpr> block_var_map;
+    // Create the block vars, the block's accessed region and the accessing indices
+    for (size_t i = 0; i < info->origin_block_vars[expr_index].size(); i++) {
+      const Var& block_var = info->origin_block_vars[expr_index][i];
+      Var var("v" + std::to_string(access_indices.size()), block_var.dtype());
+      Range range = Range::FromMinExtent(make_zero(block_var.dtype()),
+                                         info->range_map.at(iter_vars[i])->extent);
+      block_vars.push_back(IterVar(/*dom=*/range,
+                                   /*var=*/var,
+                                   /*IterVarType=*/kDataPar));
+      access_indices.push_back(var);
+      access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1)));
+      block_var_map.Set(block_var, var);
+    }
+
+    // Create the index computing block
+    PrimExpr new_expr = Substitute(index_expr, block_var_map);
+    Block block(
+        /*iter_vars=*/std::move(block_vars),
+        /*reads=*/{},
+        /*writes=*/{BufferRegion(info->cache_buffer[expr_index], access_region)},
+        /*name_hint=*/"index_" + std::to_string(expr_index),
+        /*body=*/BufferStore(info->cache_buffer[expr_index], new_expr, access_indices),
+        /*init=*/NullOpt,
+        /*alloc_buffers=*/{},
+        /*match_buffers=*/{},
+        /*annotations=*/{});
+    blocks.push_back(block);
+    // Create the block realize node
+    Stmt body = BlockRealize(/*values=*/iter_values,
+                             /*predicate=*/const_true(),
+                             /*block=*/block);
+    // Create the surrounding loops
+    for (size_t i = loop_vars.size(); i >= 1; --i) {
+      body = For(/*loop_var=*/loop_vars[i - 1],
+                 /*min=*/0,
+                 /*extent=*/info->range_map.at(iter_vars[i - 1])->extent,
+                 /*kind=*/ForKind::kSerial,
+                 /*body=*/body);
+    }
+    bodies.push_back(body);
+  }
+
+  info->cache_stage = SeqStmt(bodies);
+  return blocks;
+}
+
+/*!
+ * \brief Insert the cache stages into the specific position
+ * \param stmt A sequence of statements or a single statement that the new stage is inserted into
+ * \param pos The position where the cache stage is inserted
+ * \param stage The stage to be inserted
+ * \return A SeqStmt, the result after insertion
+ */
+Stmt InsertIndexStage(const Stmt& stmt, int pos, const Stmt& stage) {
+  if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
+    ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt);
+    result->seq.insert(result->seq.begin() + pos, stage);
+    return SeqStmt(result);
+  }
+  if (pos == 0) {
+    return SeqStmt::Flatten<Array<Stmt>>({stage, stmt});
+  }
+  ICHECK_EQ(pos, 1);
+  return SeqStmt::Flatten<Array<Stmt>>({stmt, stage});
+}
+
+/*! \brief Mutator for CacheIndex. */
+class CacheIndexRewriter : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Rewrite the AST and add the stages of writing the precomputed indices
+   * \param scope_sref The parent scope of this mutation
+   * \param info The index information
+   * \return The new AST rooted at the original parent scope
+   */
+  static Stmt Rewrite(const StmtSRef& scope_sref, IndexInfo* info) {
+    CacheIndexRewriter rewriter(scope_sref, info);
+    return rewriter(GetRef<Stmt>(scope_sref->stmt));
+  }
+
+ private:
+  explicit CacheIndexRewriter(const StmtSRef& scope_sref, IndexInfo* info)
+      : scope_sref_(scope_sref), info_(info) {
+    cache_indices_.reserve(info_->origin_block_vars.size());
+    for (const Array<Var>& group_it : info_->origin_block_vars) {
+      cache_indices_.push_back({});
+      for (const Var& it : group_it) {
+        cache_indices_.back().push_back(it);
+      }
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    Block old_stmt = GetRef<Block>(block);
+    // Mutate the body
+    Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
+
+    // Check if this is the block corresponding to the parent scope
+    if (block == scope_sref_->stmt) {
+      // If so, put the buffer allocations and the cache stages into the parent scope
+      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
+      n->body = InsertIndexStage(n->body, info_->loc_pos, info_->cache_stage);
+      for (const Buffer& it : info_->cache_buffer) {
+        n->alloc_buffers.push_back(it);
+      }
+      stmt = Block(n);
+    }
+    info_->block_reuse.Set(old_stmt, stmt);
+    return std::move(stmt);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    if (load->buffer.same_as(info_->target_buffer)) {
+      // Rewrite the load of the target buffer
+      Array<PrimExpr> new_indices;
+      for (const PrimExpr& index : load->indices) {
+        auto it = std::find_if(info_->index_exprs.begin(), info_->index_exprs.end(),
+                               [&](const PrimExpr& e) { return e.get() == index.get(); });
+        if (it == info_->index_exprs.end()) {
+          new_indices.push_back(index);
+        } else {
+          // Replace the load index with a load of the cached index
+          auto offset = std::distance(info_->index_exprs.begin(), it);
+          new_indices.push_back(BufferLoad(info_->cache_buffer[offset], cache_indices_[offset]));
+        }
+      }
+      return BufferLoad(load->buffer, new_indices);
+    }
+    return ExprMutator::VisitExpr_(load);
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
+    return PrimExpr();
+  }
+
+ private:
+  /*! \brief The parent scope of the insertion */
+  const StmtSRef& scope_sref_;
+  /*! \brief The info for inserting the cache stages */
+  IndexInfo* info_;
+  /*! \brief The indices for loading from the cache buffers */
+  std::vector<Array<PrimExpr>> cache_indices_;
+};
+
+Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index) {
+  /*!
+   * Check:
+   *   - The index is in the array of the block's read regions
+   *
+   * Mutate:
+   *   - Allocate new cache buffers under the current scope.
+   *   - Precompute the indices and store them in the cache buffers.
+   */
+
+  // Step 0. Check the index, get the target buffer and the parent scope
+  IndexInfo info;
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  info.target_buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kRead);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
+
+  // Step 1. Collect the indexing info of the target buffer.
+  IndexInfoCollector::Collect(self, block_sref, scope_sref, &info);
+
+  // Step 2. Create the cache stages and rewrite the stmt.
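+  // The block realize of the target block carries the block var bindings, from which
+  // MakeIndexCacheStage derives the loop extents and the shapes of the cache buffers.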
+  BlockRealize realize = GetBlockRealize(self, block_sref);
+  info.var_binding = GetBindings(realize);
+  Array<Block> cache_stages = MakeIndexCacheStage(&info);
+  Stmt new_scope = CacheIndexRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
+
+  bool old_stage_pipeline = self->block_info[block_sref].scope->stage_pipeline;
+
+  // Step 3. Replace the stmt and update the block info flags.
+  self->Replace(scope_sref, new_scope, info.block_reuse);
+  Array<StmtSRef> result_block_srefs;
+  for (const Block& it : cache_stages) {
+    StmtSRef result_block_sref = self->stmt2ref.at(it.get());
+    result_block_srefs.push_back(result_block_sref);
+    BlockInfo& block_info = self->block_info[result_block_sref];
+
+    bool affine_binding = false;
+    if (result_block_sref->parent == nullptr) {
+      affine_binding = true;
+    } else {
+      arith::Analyzer analyzer;
+      StmtSRef parent_sref = GetRef<StmtSRef>(result_block_sref->parent);
+      affine_binding = IsAffineBinding(/*realize=*/GetBlockRealize(self, result_block_sref),
+                                       /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref),
+                                       /*analyzer=*/&analyzer);
+    }
+
+    block_info.affine_binding = affine_binding;
+    block_info.region_cover = true;
+    block_info.scope->stage_pipeline = old_stage_pipeline;
+  }
+
+  return result_block_srefs;
+}
+
+/******** InstructionKind Registration ********/
+
+struct CacheIndexTraits : public UnpackedInstTraits<CacheIndexTraits> {
+  static constexpr const char* kName = "CacheIndex";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 1;
+  static constexpr size_t kNumDecisions = 0;
+
+  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index) {
+    return sch->CacheIndex(block, buffer_index->value);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block, Integer buffer_index) {
+    PythonAPICall py("cache_index");
+    py.Input("block", block);
+    py.Input("buffer_index", buffer_index->value);
+    py.OutputList(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(CacheIndexTraits);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 280d0af92a8c..6425ae0766ae 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -181,6 +181,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite")
     .set_body_method<Schedule>(&ScheduleNode::CacheWrite);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheInplace")
     .set_body_method<Schedule>(&ScheduleNode::CacheInplace);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheIndex")
+    .set_body_method<Schedule>(&ScheduleNode::CacheIndex);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex")
     .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
                        int buffer_index_type) {
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index b67b008feda4..f2ad27fb6962 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -323,6 +323,20 @@ Array<BlockRV> TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int rea
   return result;
 }
 
+Array<BlockRV> TracedScheduleNode::CacheIndex(const BlockRV& block_rv, int buffer_index) {
+  Array<BlockRV> result = ConcreteScheduleNode::CacheIndex(block_rv, buffer_index);
+  Array<ObjectRef> outputs;
+  for (const BlockRV& r : result) {
+    outputs.push_back(r);
+  }
+  static const InstructionKind& kind = InstructionKind::Get("CacheIndex");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{block_rv},
+                                      /*attrs=*/{Integer(buffer_index)},
+                                      /*outputs=*/outputs));
+  return result;
+}
+
 BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
                                     BufferIndexType buffer_index_type) {
   BlockRV result = ConcreteScheduleNode::ReIndex(block_rv, buffer_index, buffer_index_type);
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 016de60726b9..06128c1a6ebc 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -80,6 +80,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
                            const String& storage_scope) final;
   BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
                   BufferIndexType buffer_index_type) final;
+  Array<BlockRV> CacheIndex(const BlockRV& block_rv, int buffer_index) final;
   /******** Schedule: Compute location ********/
   void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
                  int index = -1) final;
diff --git a/tests/python/unittest/test_tir_schedule_cache_index.py b/tests/python/unittest/test_tir_schedule_cache_index.py
new file mode 100644
index 000000000000..0c2882d1b617
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_cache_index.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+########## Function before schedule ##########
+
+
+@T.prim_func
+def resize(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, (1, 3, 40, 40))
+    B = T.match_buffer(b, (1, 3, 80, 80))
+    for i0, i1, i2, i3 in T.grid(1, 3, 80, 80):
+        with T.block("A"):
+            n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3])
+            B[n, c, vi, vj] = A[n, c, vi // 4 + vj // 4, vj // 2]
+
+
+@T.prim_func
+def resize_cache_index(
+    A: T.Buffer[(1, 3, 40, 40), "float32"], B: T.Buffer[(1, 3, 80, 80), "float32"]
+) -> None:
+    index_var_0 = T.alloc_buffer([80, 80], dtype="int32", strides=[1])
+    index_var_1 = T.alloc_buffer([80], dtype="int32", strides=[1])
+    for ax0, ax1 in T.grid(80, 80):
+        with T.block("index_0"):
+            v0 = T.axis.spatial(80, ax0)
+            v1 = T.axis.spatial(80, ax1)
+            T.reads()
+            T.writes(index_var_0[v0, v1])
+            index_var_0[v0, v1] = v0 // 4 + v1 // 4
+    for ax0 in T.serial(80):
+        with T.block("index_1"):
+            v0 = T.axis.spatial(80, ax0)
+            T.reads()
+            T.writes(index_var_1[v0])
+            index_var_1[v0] = v0 // 2
+    for i0, i1, i2, i3 in T.grid(1, 3, 80, 80):
+        with T.block("A"):
+            n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3])
+            T.reads(A[n, c, vi // 4 + vj // 4, vj // 2])
+            T.writes(B[n, c, vi, vj])
+            B[n, c, vi, vj] = A[n, c, index_var_0[vi, vj], index_var_1[vj]]
+
+
+def test_cache_index():
+    sch = tvm.tir.Schedule(resize, debug_mask="all")
+    block = sch.get_block("A")
+    sch.cache_index(block, 0)
+    tvm.ir.assert_structural_equal(resize_cache_index, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=resize)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()