diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 8b22c173a3d8..288601d1cccc 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -419,10 +419,12 @@ class ScheduleNode : public runtime::Object {
    * \brief Create a block to cache precomputed index for later use.
    * if there is no index computation, keep unchanged.
    * \param block_rv The target block
-   * \param buffer_index The index of the target buffer in block's read region
+   * \param storage_scope The storage scope of the cached block
+   * \param cse_thresh The repeat threshold that determines a common subexpression
    * \return The cache stage blocks.
    */
-  virtual Array<BlockRV> CacheIndex(const BlockRV& block_rv, int buffer_index) = 0;
+  virtual Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope,
+                                    int cse_thresh) = 0;
   /*!
    * \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
    * The layout of the cache will be the same as by the iterators of the block that reads/writes the
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 64aba0e029fe..6a71e5872fcd 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1294,7 +1294,10 @@ def cache_inplace(data_io: T.Buffer[64, "int32"]) -> None:
 
     @type_checked
     def cache_index(
-        self, block: Union[BlockRV, str], buffer_index: Union[int, str, Buffer]
+        self,
+        block: Union[BlockRV, str],
+        storage_scope: str,
+        cse_thresh: int = 0,
     ) -> List[BlockRV]:
         """Create a block to cache precomputed index for later use.
         if there is no index computation, keep unchanged.
@@ -1304,8 +1307,12 @@ def cache_index(
         block : Union[BlockRV, str]
             The target block operates on the target buffer.
 
-        buffer_index: int
-            The index of the target buffer in block's read region
+        storage_scope: str
+            The storage scope of the cached block.
+
+        cse_thresh: int
+            The repeat threshold that determines a common subexpression;
+            the default of 0 caches every index computation.
         Returns
@@ -1334,7 +1341,7 @@ def resize(a: T.handle, b: T.handle) -> None:
 
             sch = tir.Schedule(resize)
             block_a = sch.get_block("A")
-            sch.cache_index(block_a, 0)
+            sch.cache_index(block_a, "global", 1)
             print(sch.mod["main"].script())
 
         After applying cache_index, the IR becomes:
@@ -1370,12 +1377,8 @@ def resize_cache_index(
         """
         block = self._normalize_block_arg(block)
 
-        if not isinstance(buffer_index, int):
-            _, buffer_index, _ = self._normalize_buffer_arg(
-                block, buffer_index, required_buffer_type="read"
-            )
         return _ffi_api.ScheduleCacheIndex(  # type: ignore # pylint: disable=no-member
-            self, block, buffer_index
+            self, block, storage_scope, cse_thresh
         )
 
     @type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 163c72eb0777..91ca0f141766 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -581,10 +581,11 @@ Array<BlockRV> ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int w
   return return_blocks;
 }
 
-Array<BlockRV> ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv, int buffer_index) {
+Array<BlockRV> ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv,
+                                                const String& storage_scope, int cse_thresh) {
   Array<StmtSRef> result;
   TVM_TIR_SCHEDULE_BEGIN();
-  result = tir::CacheIndex(state_, this->GetSRef(block_rv), buffer_index);
+  result = tir::CacheIndex(state_, this->GetSRef(block_rv), storage_scope, cse_thresh);
   TVM_TIR_SCHEDULE_END("cache-index", this->error_render_level_);
   this->state_->DebugVerify();
   Array<BlockRV> return_blocks;
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 899775f2a15d..95d5fe9c2e44 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -118,7 +118,8 @@ class ConcreteScheduleNode : public ScheduleNode {
                             const Array<BlockRV> consumer_blocks = {}) override;
   Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
                               const String& storage_scope) override;
-  Array<BlockRV> CacheIndex(const BlockRV& block_rv, int write_buffer_index) override;
+  Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope,
+                            int cse_thresh) override;
   BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
                   BufferIndexType buffer_index_type) override;
   /******** Schedule: Compute location ********/
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 9e7f77f55ea5..dbc4e235965c 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -285,11 +285,12 @@ TVM_DLL Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_s
  * \brief Create a block to cache precomputed index for later use.
  * if there is no index computation, keep unchanged.
  * \param block_sref The target block
- * \param buffer_index The index of the target buffer in block's read region,
+ * \param storage_scope The storage scope of the cached block
+ * \param cse_thresh The repeat threshold that determines a common subexpression
  * \return The cache stage block.
  */
 TVM_DLL Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,
-                                   int buffer_index);
+                                   const String& storage_scope, int cse_thresh);
 /*!
  *!
  * \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
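
Note: a minimal usage sketch of the new signature, assuming this patch is applied. The `gather` kernel below is a hypothetical example, not part of the patch:

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def gather(A: T.Buffer[(64,), "float32"], B: T.Buffer[(64,), "float32"]) -> None:
        for i in T.serial(64):
            with T.block("B"):
                vi = T.axis.remap("S", [i])
                # The index computation below is a pure, non-trivial expression,
                # so it is a candidate for extraction into an index cache buffer.
                B[vi] = A[vi // 4 + vi % 4]

    sch = tvm.tir.Schedule(gather)
    block = sch.get_block("B")
    # storage_scope picks the scope of the allocated index buffers; with the
    # default cse_thresh=0 every eligible index computation is precomputed,
    # while a larger threshold only extracts subexpressions repeated at least
    # that many times in the store.
    sch.cache_index(block, "global", 0)
    print(sch.mod["main"].script())
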
diff --git a/src/tir/schedule/primitive/cache_index.cc b/src/tir/schedule/primitive/cache_index.cc
index c6f845541dd2..0316feefd5de 100644
--- a/src/tir/schedule/primitive/cache_index.cc
+++ b/src/tir/schedule/primitive/cache_index.cc
@@ -18,6 +18,8 @@
  */
 #include
 
+#include "../../transforms/common_subexpr_elim_tools.h"
+#include "../../transforms/replace_selected_expr.h"
 #include "../utils.h"
 
 namespace tvm {
@@ -27,8 +29,10 @@
 namespace tir {
 
 /*! \brief The auxiliary info used for the insertion point and content of the cache stage. */
 struct IndexInfo {
-  /*! \brief The target buffer to cache the index. */
-  Buffer target_buffer;
+  /*! \brief The target block to perform cache_index */
+  StmtSRef target_block;
+  /*! \brief Record the common subexpr extract threshold */
+  size_t cse_thresh;
   /*! \brief The cache buffer to store the precomputed index */
   std::vector<Buffer> cache_buffer;
   /*! \brief The expr to be precomputed */
@@ -74,9 +78,8 @@ class IndexInfoCollector : public StmtExprVisitor {
    */
   static void Collect(const ScheduleState& self, const StmtSRef& block_sref,
                       const StmtSRef& scope_sref, IndexInfo* info) {
-    IndexInfoCollector collector(self, block_sref, scope_sref, info->target_buffer);
+    IndexInfoCollector collector(self, block_sref, scope_sref, info->cse_thresh);
     collector(GetRef<Stmt>(scope_sref->stmt));
-    // info->loc_sref = collector.loc_sref_;
     info->loc_pos = collector.loc_pos_;
     info->index_exprs = collector.exprs_;
     info->range_map = collector.range_map_;
@@ -88,11 +91,11 @@
    * \param self The state of the schedule
    * \param block_sref The sref of the target block of the buffer being applied cache_index
    * \param scope_sref The sref of the scope block of the target block
-   * \param buffer The target buffer to cache the indexs
+   * \param cse_thresh The repeat threshold that determines a common subexpression
    */
   IndexInfoCollector(const ScheduleState self, const StmtSRef& block_sref,
-                     const StmtSRef& scope_sref, const Buffer& buffer)
-      : self_(self), block_sref_(block_sref), scope_sref_(scope_sref), buffer_(buffer) {}
+                     const StmtSRef& scope_sref, int cse_thresh)
+      : self_(self), block_sref_(block_sref), scope_sref_(scope_sref), cse_thresh_(cse_thresh) {}
 
   void VisitStmt_(const SeqStmtNode* seq_stmt) final {
     for (size_t i = 0; i < seq_stmt->size(); ++i) {
@@ -110,8 +113,9 @@
   }
 
   void VisitStmt_(const BlockNode* block) final {
-    // Only visit the target's parent block
+    visiting_target_block = static_cast<bool>(block_sref_->stmt == block);
     StmtVisitor::VisitStmt_(block);
+    visiting_target_block = false;
     if (block == scope_sref_->stmt) {
       // The block vistied is the current parent scope
       // Handling cases when no SeqStmt in the scope
@@ -136,15 +140,56 @@
     }
   }
 
-  void VisitExpr_(const BufferLoadNode* load) final {
-    if (load->buffer.same_as(buffer_)) {
-      for (const PrimExpr& it : load->indices) {
-        if (!it->IsInstance<VarNode>()) {
-          exprs_.push_back(it);
+  void VisitStmt_(const BufferStoreNode* store) final {
+    // Only analyze the cache candidate for stores in target block
+    if (visiting_target_block) {
+      auto IsEligibleComputation = [](const PrimExpr& expr) {
+        return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 &&
+                (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
+      };
+
+      // Analyze sub expr candidates
+      ComputationTable table_syntactic_comp_done_by_stmt =
+          ComputationsDoneBy::GetComputationsDoneBy(GetRef<Stmt>(store),
+                                                    IsEligibleComputation,
+                                                    [](const PrimExpr& expr) { return true; });
+      std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
+          SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, true);
+
+      // Analyze the sub expr of a candidate whose repeat time is under cse_thresh_
+      for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
+        std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i];
+        if (computation_and_nb.second < cse_thresh_) {
+          std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
+              computation_and_nb.first, IsEligibleComputation,
+              [](const PrimExpr& expr) { return true; });
+          InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs,
+                                                   true, computation_and_nb.second);
         }
       }
+
+      // Record the final sub expr with repeat time greater than cse_thresh_
+      // In order to make the result stable, sort it by post order and then by complexity
+      PostOrderVisit(store->value, [&semantic_comp_done_by_stmt, this](const ObjectRef& node) {
+        if (node->IsInstance<PrimExprNode>()) {
+          PrimExpr this_expr = Downcast<PrimExpr>(node);
+          for (auto& it : semantic_comp_done_by_stmt) {
+            if (it.second >= this->cse_thresh_ && EquivalentTerms(this_expr, it.first, true)) {
+              auto find_result =
+                  std::find_if(this->exprs_.begin(), this->exprs_.end(),
+                               [&](PrimExpr expr) { return expr.get() == it.first.get(); });
+              if (find_result == this->exprs_.end()) {
+                this->exprs_.push_back(it.first);
+              }
+            }
+          }
+        }
+      });
+      auto cmp = [&](const PrimExpr& lhs, const PrimExpr& rhs) -> bool {
+        return CalculateExprComplexity(lhs) > CalculateExprComplexity(rhs);
+      };
+      std::stable_sort(exprs_.begin(), exprs_.end(), cmp);
     }
-    ExprVisitor::VisitExpr_(load);
+    StmtVisitor::VisitStmt_(store);
   }
 
   /*! \brief The schedule class */
@@ -153,12 +198,14 @@ class IndexInfoCollector : public StmtExprVisitor {
   const StmtSRef& block_sref_;
   /*! \brief The parent scope of the target block */
   const StmtSRef& scope_sref_;
-  /*! \brief The target buffer to cache the index */
-  const Buffer& buffer_;
+  /*! \brief Record the common subexpr extract threshold */
+  size_t cse_thresh_;
   /*! \brief The calculation expr to be precomputed */
   std::vector<PrimExpr> exprs_;
   /*! \brief The flag whether we have visited the target block */
   bool visited_block_{false};
+  /*! \brief The flag indicating currently visiting target block */
+  bool visiting_target_block{false};
   /*! \brief The index to insert the cache_index stage */
   int loc_pos_{-1};
   /*! \brief The flag indicating the right scope to update seq pos */
@@ -169,17 +216,15 @@
 
 /*!
  * \brief Create a loop nest that writes precomputed index into index buffer.
- * \param cache_region The cached copy region.
  * \param info The cache stage information, which will be updated in the function.
 * \param storage_scope The storage scope of the cached buffer (only used in naming here)
 * \returns A block indicating the body of the loop nesting.
 */
-Array<Block> MakeIndexCacheStage(IndexInfo* info) {
+Array<Block> MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) {
   Array<Block> blocks;
   Array<Stmt> bodies;
   bodies.reserve(info->index_exprs.size());
   info->cache_buffer.reserve(info->index_exprs.size());
-  const String& storage_scope = info->target_buffer.scope();
 
   // For each index calculation, create a block to pre-compute.
   for (size_t expr_index = 0; expr_index < info->index_exprs.size(); expr_index++) {
@@ -214,10 +259,7 @@ Array<Block> MakeIndexCacheStage(IndexInfo* info) {
       });
     }
 
-    // Inference the shape and create cache buffer
-    arith::IntSet val_range =
-        arith::EvalSet(Substitute(index_expr, info->var_binding), arith::AsIntSet(info->range_map));
-    DataType data_type = DetermineDatatype(val_range);
+    DataType data_type = index_expr.dtype();
     Var index_buffer_var("index_var_" + std::to_string(expr_index),
                          PointerType(PrimType(data_type), storage_scope));
     Array<PrimExpr> buffer_shape;
@@ -346,7 +388,9 @@ class CacheIndexRewriter : public StmtExprMutator {
   Stmt VisitStmt_(const BlockNode* block) final {
     Block old_stmt = GetRef<Block>(block);
     // Mutate the body
+    visiting_target_block = static_cast<bool>(block == info_->target_block->stmt);
     Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
+    visiting_target_block = false;
 
     // Check if it is the block corresponding to the parent scope
     if (block == scope_sref_->stmt) {
@@ -362,24 +406,23 @@ class CacheIndexRewriter : public StmtExprMutator {
     return std::move(stmt);
   }
 
-  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
-    if (load->buffer.same_as(info_->target_buffer)) {
-      // Rewrite the target buffer load
-      Array<PrimExpr> new_indices;
-      for (const PrimExpr& index : load->indices) {
-        auto it = std::find_if(info_->index_exprs.begin(), info_->index_exprs.end(),
-                               [&](PrimExpr& e) { return e.get() == index.get(); });
-        if (it == info_->index_exprs.end()) {
-          new_indices.push_back(index);
-        } else {
-          // Replace load index with cached index
-          auto offset = std::distance(info_->index_exprs.begin(), it);
-          new_indices.push_back(BufferLoad(info_->cache_buffer[offset], cache_indices_[offset]));
-        }
+  Stmt VisitStmt_(const BufferStoreNode* store) final {
+    Stmt ret_stmt = StmtMutator::VisitStmt_(store);
+    // Replace common sub expr for target block, with cached buffer load
+    if (visiting_target_block) {
+      for (size_t i = 0; i < info_->index_exprs.size(); i++) {
+        PrimExpr& computation = info_->index_exprs[i];
+        std::function<bool(const PrimExpr&)> predicate_selector =
+            [computation](const PrimExpr& current_expr) {
+              return (EquivalentTerms(current_expr, computation, true));
+            };
+        BufferLoad load = BufferLoad(info_->cache_buffer[i], cache_indices_[i]);
+        ret_stmt = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(
+            ret_stmt, predicate_selector, std::move(load),
+            [](const PrimExpr& expr) { return true; });
       }
-      return BufferLoad(load->buffer, new_indices);
     }
-    return ExprMutator::VisitExpr_(load);
+    return ret_stmt;
   }
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
@@ -393,9 +436,12 @@
   IndexInfo* info_;
   /*! \brief The indices for the cache buffer */
   std::vector<Array<PrimExpr>> cache_indices_;
+  /*! \brief The flag indicating we are currently visiting the target block */
+  bool visiting_target_block{false};
 };
 
-Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index) {
+Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,
+                           const String& storage_scope, int cse_thresh) {
   /*!
    * Check:
    *   - The index is in the array of block reading region
    */
@@ -407,9 +453,9 @@ Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref, int b
   // Step 0. Checking index, getting the target buffer and the parent scope
   IndexInfo info;
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
-  info.target_buffer =
-      GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kRead);
+  info.target_block = block_sref;
+  CHECK_GE(cse_thresh, 0) << "cse_thresh should not be a negative number";
+  info.cse_thresh = cse_thresh;
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
 
   // Step 1. Collect the indexing info of target buffer.
@@ -418,7 +464,7 @@ Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref, int b
   // Step 2. Create cache stages and rewrite the stmt.
   BlockRealize realize = GetBlockRealize(self, block_sref);
   info.var_binding = GetBindings(realize);
-  Array<Block> cache_stages = MakeIndexCacheStage(&info);
+  Array<Block> cache_stages = MakeIndexCacheStage(&info, storage_scope);
   Stmt new_scope = CacheIndexRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
 
   bool old_stage_pipeline = self->block_info[block_sref].scope->stage_pipeline;
@@ -458,17 +504,20 @@ struct CacheIndexTraits : public UnpackedInstTraits<CacheIndexTraits> {
 
  private:
   static constexpr size_t kNumInputs = 1;
-  static constexpr size_t kNumAttrs = 1;
+  static constexpr size_t kNumAttrs = 2;
   static constexpr size_t kNumDecisions = 0;
 
-  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index) {
-    return sch->CacheIndex(block, buffer_index->value);
+  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block, String storage_scope,
+                                                Integer cse_thresh) {
+    return sch->CacheIndex(block, storage_scope, cse_thresh->value);
   }
 
-  static String UnpackedAsPython(Array<String> outputs, String block, Integer buffer_index) {
+  static String UnpackedAsPython(Array<String> outputs, String block, String storage_scope,
+                                 Integer cse_thresh) {
     PythonAPICall py("cache_index");
     py.Input("block", block);
-    py.Input("buffer_index", buffer_index->value);
+    py.Input("storage_scope", storage_scope);
+    py.Input("cse_thresh", cse_thresh->value);
     py.OutputList(outputs);
     return py.Str();
   }
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 70559608e789..3dc78074fcd6 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -325,8 +325,9 @@ Array<BlockRV> TracedScheduleNode::CacheInplace(const BlockRV& block_rv, int rea
   return result;
 }
 
-Array<BlockRV> TracedScheduleNode::CacheIndex(const BlockRV& block_rv, int buffer_index) {
-  Array<BlockRV> result = ConcreteScheduleNode::CacheIndex(block_rv, buffer_index);
+Array<BlockRV> TracedScheduleNode::CacheIndex(const BlockRV& block_rv, const String& storage_scope,
+                                              int cse_thresh) {
+  Array<BlockRV> result = ConcreteScheduleNode::CacheIndex(block_rv, storage_scope, cse_thresh);
   Array<BlockRV> outputs;
   for (const BlockRV& r : result) {
     outputs.push_back(r);
@@ -334,7 +335,7 @@ Array<BlockRV> TracedScheduleNode::CacheIndex(const BlockRV& block_rv, int buffe
   static const InstructionKind& kind = InstructionKind::Get("CacheIndex");
   trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
                                       /*inputs=*/{block_rv},
-                                      /*attrs=*/{Integer(buffer_index)},
+                                      /*attrs=*/{storage_scope, Integer(cse_thresh)},
                                       /*outputs=*/outputs));
   return result;
 }
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index c54574e9c9ff..ee65c721ad9f 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -80,7 +80,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
                            const String& storage_scope) final;
   BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) final; - Array CacheIndex(const BlockRV& block_rv, int buffer_index) final; + Array CacheIndex(const BlockRV& block_rv, const String& storage_scope, + int cse_thresh) final; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) final; diff --git a/tests/python/unittest/test_tir_schedule_cache_index.py b/tests/python/unittest/test_tir_schedule_cache_index.py index 0c2882d1b617..d446249e018e 100644 --- a/tests/python/unittest/test_tir_schedule_cache_index.py +++ b/tests/python/unittest/test_tir_schedule_cache_index.py @@ -47,8 +47,7 @@ def resize_cache_index( index_var_1 = T.alloc_buffer([80], dtype="int32", strides=[1]) for ax0, ax1 in T.grid(80, 80): with T.block("index_0"): - v0 = T.axis.spatial(80, ax0) - v1 = T.axis.spatial(80, ax1) + v0, v1 = T.axis.remap("SS", [ax0, ax1]) T.reads() T.writes(index_var_0[v0, v1]) index_var_0[v0, v1] = v0 // 4 + v1 // 4 @@ -66,13 +65,408 @@ def resize_cache_index( B[n, c, vi, vj] = A[n, c, index_var_0[vi, vj], index_var_1[vj]] -def test_inplace_cache_read(): +@T.prim_func +def bilinear_resize( + x: T.Buffer[(1, 3, 40, 40), "float16"], resize: T.Buffer[(1, 3, 80, 80), "float16"] +): + for i0, i1, i2, i3 in T.grid(1, 3, 80, 80): + with T.block("resize"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(x[i0_1, i1_1, 0:40, 0:40]) + T.writes(resize[i0_1, i1_1, i2_1, i3_1]) + resize[i0_1, i1_1, i2_1, i3_1] = T.Cast( + "float16", + ( + T.Cast( + "float32", + x[ + i0_1, + i1_1, + T.max( + T.min( + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i2_1) + T.float32(0.5)) + * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ), + 39, + ), + 0, + ), + T.max( + T.min( + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i3_1) + T.float32(0.5)) + * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ), + 39, + ), + 0, + ), + ], + ) + * ( + T.float32(1) + - ( + (T.Cast("float32", i3_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5) + - T.Cast( + "float32", + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i3_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ), + ) + ) + ) + + T.Cast( + "float32", + x[ + i0_1, + i1_1, + T.max( + T.min( + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i2_1) + T.float32(0.5)) + * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ), + 39, + ), + 0, + ), + T.max( + T.min( + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i3_1) + T.float32(0.5)) + * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ) + + 1, + 39, + ), + 0, + ), + ], + ) + * ( + (T.Cast("float32", i3_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5) + - T.Cast( + "float32", + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i3_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ), + ) + ) + ) + * ( + T.float32(1) + - ( + (T.Cast("float32", i2_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5) + - T.Cast( + "float32", + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i2_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ), + ) + ) + ) + + ( + T.Cast( + "float32", + x[ + i0_1, + i1_1, + T.max( + T.min( + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i2_1) + T.float32(0.5)) + * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ) + + 1, + 39, + ), + 0, + ), + T.max( + T.min( + T.Cast( + "int32", + T.floor( + (T.Cast("float32", 
i3_1) + T.float32(0.5)) + * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ), + 39, + ), + 0, + ), + ], + ) + * ( + T.float32(1) + - ( + (T.Cast("float32", i3_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5) + - T.Cast( + "float32", + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i3_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ), + ) + ) + ) + + T.Cast( + "float32", + x[ + i0_1, + i1_1, + T.max( + T.min( + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i2_1) + T.float32(0.5)) + * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ) + + 1, + 39, + ), + 0, + ), + T.max( + T.min( + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i3_1) + T.float32(0.5)) + * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ) + + 1, + 39, + ), + 0, + ), + ], + ) + * ( + (T.Cast("float32", i3_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5) + - T.Cast( + "float32", + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i3_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ), + ) + ) + ) + * ( + (T.Cast("float32", i2_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5) + - T.Cast( + "float32", + T.Cast( + "int32", + T.floor( + (T.Cast("float32", i2_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ), + ) + ), + ) + + +@T.prim_func +def cached_bilinear_resize( + x: T.Buffer[(1, 3, 40, 40), "float16"], resize: T.Buffer[(1, 3, 80, 80), "float16"] +): + index_var_0 = T.alloc_buffer([80], dtype="float32", strides=[1]) + index_var_1 = T.alloc_buffer([80], dtype="int32", strides=[1]) + index_var_2 = T.alloc_buffer([80], dtype="int32", strides=[1]) + for ax0 in T.serial(80): + with T.block("index_0"): + v0 = T.axis.spatial(80, ax0) + T.reads() + T.writes(index_var_0[v0]) + index_var_0[v0] = ( + (T.Cast("float32", v0) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5) + - T.Cast( + "float32", + T.Cast( + "int32", + T.floor( + (T.Cast("float32", v0) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5), + dtype="float32", + ), + ), + ) + ) + for ax0 in T.serial(80): + with T.block("index_1"): + v0 = T.axis.spatial(80, ax0) + T.reads() + T.writes(index_var_1[v0]) + index_var_1[v0] = T.Cast( + "int32", + T.floor( + (T.Cast("float32", v0) + T.float32(0.5)) * T.float32(0.5) - T.float32(0.5), + dtype="float32", + ), + ) + for ax0 in T.serial(80): + with T.block("index_2"): + v0 = T.axis.spatial(80, ax0) + T.reads() + T.writes(index_var_2[v0]) + index_var_2[v0] = T.Cast( + "int32", + T.floor( + (T.Cast("float32", v0) + T.float32(0.5)) * T.float32(0.5) - T.float32(0.5), + dtype="float32", + ), + ) + for i0, i1, i2, i3 in T.grid(1, 3, 80, 80): + with T.block("resize"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(x[i0_1, i1_1, 0:40, 0:40]) + T.writes(resize[i0_1, i1_1, i2_1, i3_1]) + resize[i0_1, i1_1, i2_1, i3_1] = T.Cast( + "float16", + ( + T.Cast( + "float32", + x[ + i0_1, + i1_1, + T.max(T.min(index_var_1[i2_1], 39), 0), + T.max(T.min(index_var_2[i3_1], 39), 0), + ], + ) + * (T.float32(1) - index_var_0[i3_1]) + + T.Cast( + "float32", + x[ + i0_1, + i1_1, + T.max(T.min(index_var_1[i2_1], 39), 0), + T.max(T.min(index_var_2[i3_1] + 1, 39), 0), + ], + ) + * index_var_0[i3_1] + ) + * ( + T.float32(1) + - ( + (T.Cast("float32", i2_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5) + - T.Cast("float32", index_var_1[i2_1]) + ) + ) + + ( + T.Cast( + "float32", + x[ + i0_1, + i1_1, + T.max(T.min(index_var_1[i2_1] + 1, 39), 0), + 
T.max(T.min(index_var_2[i3_1], 39), 0), + ], + ) + * (T.float32(1) - index_var_0[i3_1]) + + T.Cast( + "float32", + x[ + i0_1, + i1_1, + T.max(T.min(index_var_1[i2_1] + 1, 39), 0), + T.max(T.min(index_var_2[i3_1] + 1, 39), 0), + ], + ) + * index_var_0[i3_1] + ) + * ( + (T.Cast("float32", i2_1) + T.float32(0.5)) * T.float32(0.5) + - T.float32(0.5) + - T.Cast("float32", index_var_1[i2_1]) + ), + ) + + +def test_basic_cache_index(): sch = tvm.tir.Schedule(resize, debug_mask="all") block = sch.get_block("A") - sch.cache_index(block, 0) + sch.cache_index(block, "global") tvm.ir.assert_structural_equal(resize_cache_index, sch.mod["main"]) verify_trace_roundtrip(sch=sch, mod=resize) +def test_resize_bilinear_cache_index(): + sch = tvm.tir.Schedule(bilinear_resize, debug_mask="all") + block = sch.get_block("resize") + sch.cache_index(block, "global", 4) + tvm.ir.assert_structural_equal(sch.mod["main"], cached_bilinear_resize) + verify_trace_roundtrip(sch=sch, mod=bilinear_resize) + + if __name__ == "__main__": tvm.testing.main()
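
Note: because the traced schedule now records storage_scope and cse_thresh as instruction attributes, a serialized trace round-trips through both values, which is what verify_trace_roundtrip exercises above. A minimal sketch, reusing the hypothetical `gather` kernel from the earlier note (not part of this patch):

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def gather(A: T.Buffer[(64,), "float32"], B: T.Buffer[(64,), "float32"]) -> None:
        for i in T.serial(64):
            with T.block("B"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi // 4 + vi % 4]

    sch = tvm.tir.Schedule(gather)  # tir.Schedule is traced by default
    sch.cache_index(sch.get_block("B"), "global", 1)
    # The printed trace serializes both new attributes, e.g. a line like
    #   sch.cache_index(block=..., storage_scope="global", cse_thresh=1)
    print(sch.trace)

    # Replaying the trace on a fresh schedule reproduces the transformation.
    new_sch = tvm.tir.Schedule(gather)
    sch.trace.apply_to_schedule(new_sch, remove_postproc=False)
    tvm.ir.assert_structural_equal(sch.mod["main"], new_sch.mod["main"])
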