diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 9ec2841ebd5e..9ba6f1a311de 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -681,6 +681,23 @@ class ScheduleNode : public runtime::Object {
   */
  virtual void PadEinsum(const BlockRV& block_rv, const Array& padding) = 0;
+  /******** Schedule: Buffer transformation ********/
+  /*!
+   * \brief Compute the target buffer via rolling buffering.
+   * \details This primitive selects the outermost rollable axis with a positive bound overlap that
+   * appears in the block's ancestor loops as the `rolling axis`, folds and circularizes the buffer
+   * along the rolling dimension, and appends a block predicate to avoid recomputing overlapping
+   * elements. It requires:
+   * 1) The buffer is an intermediate buffer defined via `alloc_buffer`.
+   * 2) The LCA of the producer and the consumer of the buffer is a for loop; typically,
+   * the producer and the consumer of the buffer are cascaded through compute_at.
+   * 3) The access region of the buffer has at least one dimension that contains
+   * a positive bound overlap.
+   * \param block_rv The producer block of the buffer.
+   * \param write_buffer_index The index of the buffer in the block's write region.
+   */
+  virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0;
+
  /******** Schedule: Misc ********/
  /*! \brief A no-op that marks the start of postprocessing phase of scheduling */
  virtual void EnterPostproc() = 0;
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 4814271f4023..832ca1d2e2e0 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -3042,6 +3042,114 @@ def after_pad_einsum(
            self, block, padding
        )
 
+    ######## Schedule: Buffer transformation ########
+
+    @type_checked
+    def rolling_buffer(
+        self,
+        block: Union[BlockRV, str],
+        write_buffer_index: int,
+    ) -> None:
+        """Compute the target buffer via rolling buffering: select the outermost rollable
+        axis with a positive bound overlap that appears in the block's ancestor loops
+        as the `rolling axis`, fold and circularize the buffer along the rolling dimension,
+        and append a block predicate to avoid recomputing overlapping elements. It requires:
+
+        1) The block is not an output block and has only RAW dependencies.
+
+        2) The buffer is an intermediate buffer defined via `alloc_buffer`.
+
+        3) The LCA of the producer and the consumer of the buffer is a for loop; typically,
+        the producer and the consumer of the buffer are cascaded through compute_at.
+
+        4) The access region of the buffer has at least one dimension that contains
+        a positive bound overlap.
+
+        Parameters
+        ----------
+        block : Union[BlockRV, str]
+            The producer block of the buffer.
+        write_buffer_index : int
+            The index of the buffer in the block's write region.
+
+        Examples
+        --------
+
+        Before rolling_buffer, in TensorIR, the IR is:
+
+        ..
code-block:: python + + @T.prim_func + def before_rolling_buffer( + A: T.Buffer[(12, 12), "int8"], C: T.Buffer[(8, 8), "int8"] + ) -> None: + # body + # with T.block("root") + B = T.alloc_buffer([10, 10], dtype="int8") + for i0, i1 in T.grid(2, 2): + for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3): + with T.block("B"): + ax0_1 = T.axis.spatial(10, i0 * 4 + ax0) + ax1_1 = T.axis.spatial(10, i1 * 4 + ax1) + rv0, rv1 = T.axis.remap("RR", [ax2, ax3]) + B[ax0_1, ax1_1] = T.max( + B[ax0_1, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1] + ) + for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3): + with T.block("C"): + ax0_1 = T.axis.spatial(8, i0 * 4 + ax0) + ax1_1 = T.axis.spatial(8, i1 * 4 + ax1) + rv0, rv1 = T.axis.remap("RR", [ax2, ax3]) + C[ax0_1, ax1_1] = T.max( + C[ax0_1, ax1_1], B[ax0_1 + rv0, ax1_1 + rv1] + ) + + Create the schedule and do rolling_buffer: + + .. code-block:: python + + sch = tir.Schedule(before_rolling_buffer) + sch.rolling_buffer(sch.get_block("B"), write_buffer_index=0) + print(sch.mod["main"].script()) + + After applying rolling_buffer, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_rolling_buffer( + A: T.Buffer[(12, 12), "int8"], + C: T.Buffer[(8, 8), "int8"] + ) -> None: + # body + # with T.block("root") + B = T.alloc_buffer([6, 10], dtype="int8") + for i0, i1 in T.grid(2, 2): + for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3): + with T.block("B"): + T.where((i0 < 1 or 2 <= ax0) and (i1 < 1 or 2 <= ax1)) + ax0_1 = T.axis.spatial(10, i0 * 4 + ax0) + ax1_1 = T.axis.spatial(10, i1 * 4 + ax1) + rv0, rv1 = T.axis.remap("RR", [ax2, ax3]) + B[ax0_1 % 6, ax1_1] = T.max( + B[ax0_1 % 6, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1] + ) + for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3): + with T.block("C"): + ax0_1 = T.axis.spatial(8, i0 * 4 + ax0) + ax1_1 = T.axis.spatial(8, i1 * 4 + ax1) + rv0, rv1 = T.axis.remap("RR", [ax2, ax3]) + C[ax0_1, ax1_1] = T.max( + C[ax0_1, ax1_1], B[ax0_1 % 6 + rv0, ax1_1 + rv1] + ) + + Note + ---- + The region_cover property of the consumer block of the target buffer will become false. 
+        """
+        block = self._normalize_block_arg(block)
+        return _ffi_api.ScheduleRollingBuffer(self, block, write_buffer_index)  # type: ignore # pylint: disable=no-member
+
    ########## Schedule: Misc ##########
 
    @type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 3960087cf745..5a7e3b15aea4 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -801,6 +801,8 @@ void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_
  this->state_->DebugVerify();
 }
 
+/******** Schedule: Padding ********/
+
 BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) {
  StmtSRef result{nullptr};
  TVM_TIR_SCHEDULE_BEGIN();
@@ -816,6 +818,16 @@ void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const Arrayerror_render_level_);
  this->state_->DebugVerify();
 }
+
+/******** Schedule: Buffer transformation ********/
+
+void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::RollingBuffer(state_, this->GetSRef(block_rv), write_buffer_index);
+  TVM_TIR_SCHEDULE_END("rolling-buffer", this->error_render_level_);
+  this->state_->DebugVerify();
+}
+
 /******** Schedule: Misc ********/
 
 }  // namespace tir
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index bfdc082d4ce6..9e001b139751 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -153,6 +153,8 @@ class ConcreteScheduleNode : public ScheduleNode {
                        const Array& axis_separators) override;
  /******** Schedule: Padding decomposition ********/
  BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override;
+  /******** Schedule: Buffer transformation ********/
+  void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) override;
  /******** Schedule: Misc ********/
 
  void EnterPostproc() override {}
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 88331fb5b9d3..02ca0dcab432 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -524,6 +524,22 @@ TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref
 TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref,
                       const Array& padding);
 
+/******** Schedule: Buffer transformation ********/
+/*!
+ * \brief Compute the target buffer via rolling buffering.
+ * \details This primitive selects the outermost rollable axis with a positive bound overlap that
+ * appears in the block's ancestor loops as the `rolling axis`, folds and circularizes the buffer
+ * along the rolling dimension, and appends a block predicate to avoid recomputing overlapping
+ * elements. It requires:
+ * 1) The buffer is an intermediate buffer defined via `alloc_buffer`.
+ * 2) The LCA of the producer and the consumer of the buffer is a for loop; typically,
+ * the producer and the consumer of the buffer are cascaded through compute_at.
+ * 3) The access region of the buffer has at least one dimension that contains
+ * a positive bound overlap.
+ * \param block_sref The sref of the producer block of the buffer.
+ * \param write_buffer_index The index of the buffer in the block's write region.
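+ * \note A worked example, taken from the cascaded max-pool tests added in this patch: with a
+ * 3x3 window and the consumer tiled by 4 rows, each tile iteration accesses a 6-row window of
+ * the intermediate buffer while the loop advances by 4 rows, so the bound overlap is 2; the
+ * buffer is folded to 6 rows and indexed modulo 6, and a predicate skips recomputing the 2
+ * overlapping rows on all but the first iteration.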
+ */ +TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index); /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/primitive/rolling_buffer.cc b/src/tir/schedule/primitive/rolling_buffer.cc new file mode 100644 index 000000000000..c01d6c568fcd --- /dev/null +++ b/src/tir/schedule/primitive/rolling_buffer.cc @@ -0,0 +1,474 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + PrimExpr rolling_extent; + std::vector axis_overlaps; + std::vector> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. */ + Map block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map& dom_map) { + Array relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferDependencyError : public ScheduleError { + public: + explicit RollingBufferDependencyError(IRModule mod, Block block) + : mod_(mod), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The target block is required to have only RAW dependencies"; + } + + String DetailRenderTemplate() const final { + return "The target block {0} is required to have only RAW dependencies"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + /*! + * \brief Check if the block has only RAW dependencies. + * \param self The schedule state + * \param block_sref The sref of the block to be checked + * \param scope_root_sref The sref of the scope root + * \throw ScheduleError if the block has WAW or WAR dependency. 
+   */
+  static void Check(const ScheduleState& self, const StmtSRef& block_sref,
+                    const StmtSRef& scope_root_sref) {
+    BlockScope scope = self->GetBlockScope(scope_root_sref);
+    for (const Dependency& producers : scope->GetDepsByDst(block_sref)) {
+      if (!(producers->kind == DepKind::kRAW)) {
+        const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+        throw RollingBufferDependencyError(self->mod, GetRef(block));
+      }
+    }
+    for (const Dependency& consumers : scope->GetDepsBySrc(block_sref)) {
+      if (!(consumers->kind == DepKind::kRAW)) {
+        const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+        throw RollingBufferDependencyError(self->mod, GetRef(block));
+      }
+    }
+  }
+
+ private:
+  IRModule mod_;
+  Block block_;
+};
+
+class RollingBufferMatchError : public ScheduleError {
+ public:
+  RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region)
+      : mod_(mod), block_(block), buffer_region_(buffer_region) {}
+  String FastErrorString() const final {
+    return "ScheduleError: rolling_buffer expects the buffer region to have at least one dimension "
+           "matching the rolling pattern such as: hh.outer * stride + hh.inner";
+  }
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The target buffer " << buffer_region_->buffer->name << " with region "
+       << buffer_region_->region
+       << " should have at least one dimension range that matches a rolling pattern "
+          "such as hh.outer * stride + hh.inner. ";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array LocationsOfInterest() const final { return {block_}; }
+
+ private:
+  IRModule mod_;
+  Block block_;
+  BufferRegion buffer_region_;
+};
+
+class RollingBufferInsertionError : public ScheduleError {
+ public:
+  RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block)
+      : mod_(mod), buffer_(std::move(buffer)), block_(block) {}
+  String FastErrorString() const final {
+    return "ScheduleError: rolling_buffer injection is invalid: the LCA of the access "
+           "location of the target buffer is not a for loop. ";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "rolling_buffer injection is invalid. The block {0} should be tiled so that "
+       << "the LCA of the access location of the target buffer " << buffer_->name
+       << " is a for loop. 
"; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector> bound_iter_vars; + std::vector bound_overlaps; + + arith::PVar p_var; + arith::PVar p_stride, p_divisor; + for (auto bound : region) { + auto stride = 0; + auto divisor = 1; + + Optional iter_var; + if (floordiv((p_var * p_stride), p_divisor).Match(bound->min)) { + // Handle the case of fractional strides + // They take this form: floordiv(hh.outer, 2) + // Strip the floordiv and keep track of the divisor + iter_var = p_var.Eval(); + divisor = p_divisor.Eval()->value; + stride = std::ceil(static_cast(p_stride.Eval()->value) / divisor); + } else if ((p_var * p_stride).Match(bound->min)) { + // The bound is the iter var multiplied by the stride + iter_var = p_var.Eval(); + stride = p_stride.Eval()->value; + } else if (p_var.Match(bound->min)) { + // If the bound is just a Var, that implies the stride is 1 + iter_var = p_var.Eval(); + stride = 1; + } else if (is_const_int(bound->min)) { + // If the bound is an int, we can't roll over it + iter_var = NullOpt; + } else { + // If all of the above matches fail, we're in unknown behaviour + return false; + } + auto bound_overlap = 0; + if (iter_var.defined()) { + auto extent = Downcast(bound->extent)->value; + bound_overlap = extent - stride; + // Since Pass CompactBufferAllocation will be responsible for compacting the buffer + // allocation region, there is no need to roll over the axis where the overlap is not + // positive, so reset iter_var to NullOpt. 
+ if (bound_overlap <= 0) { + iter_var = NullOpt; + } + } + bound_iter_vars.push_back(iter_var); + bound_overlaps.push_back(bound_overlap); + } + + Array loop_srefs = GetLoops(block_sref); + // Pick the outermost iter_var that's mentioned in the bounds + // to be the rolling axis + Optional roll_iter_var; + int roll_axis; + for (const tir::StmtSRef& loop_sref : loop_srefs) { + auto loop_var = loop_sref->StmtAs()->loop_var; + + auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional var) { + return var && (var.get() == loop_var.get()); + })}; + if (it != bound_iter_vars.end()) { + auto i = std::distance(bound_iter_vars.begin(), it); + roll_iter_var = loop_var; + roll_axis = i; + break; + } + } + + if (!roll_iter_var.defined()) { + return false; + } + Array new_shape = buffer->shape; + new_shape.Set(roll_axis, region[roll_axis]->extent); + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->shape = new_shape; + + info_.old_buffer = buffer; + info_.new_buffer = new_buffer; + info_.rolling_axis = roll_axis; + info_.rolling_extent = region[roll_axis]->extent; + info_.axis_overlaps = bound_overlaps; + info_.axis_iter_vars = bound_iter_vars; + + return true; + } + + RollingBufferInfo info_; +}; + +class RollingBufferRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) { + RollingBufferRewriter rewriter(scope_sref, info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info) + : scope_sref_(scope_sref), info_(info) {} + + void RewriteAccessRegion(Array* old_access_regions, + const Array& infered_access_regions) { + auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { + if (buffer_region->buffer.same_as(info_->old_buffer)) { + ICHECK(infered_access_regions.size() == 1); + return infered_access_regions[0]; + } + return buffer_region; + }; + (*old_access_regions).MutateByApply(fmutate); + } + + void RewriteBufferAccess(Buffer* buffer, Array* indices) const { + Array new_indices; + new_indices.reserve(indices->size()); + // First modify the access indices to use modulo arithmetic + // for the rolling axis + for (size_t i = 0; i < indices->size(); ++i) { + if (static_cast(i) == info_->rolling_axis) { + new_indices.push_back(FloorMod((*indices)[i], info_->rolling_extent)); + } else { + new_indices.push_back((*indices)[i]); + } + } + // Replace the accessed buffer with the new buffer. + *buffer = info_->new_buffer; + *indices = std::move(new_indices); + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef(block); + Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); + BlockNode* n = stmt.CopyOnWrite(); + if (block == scope_sref_->stmt) { + Array new_alloc_buffers; + for (const Buffer& buffer : stmt->alloc_buffers) { + if (buffer != info_->old_buffer) { + new_alloc_buffers.push_back(buffer); + } else { + new_alloc_buffers.push_back(info_->new_buffer); + } + } + n->alloc_buffers = std::move(new_alloc_buffers); + } else { + Array new_iter_vars; + for (size_t i = 0; i < stmt->iter_vars.size(); ++i) { + auto old_iter_var = stmt->iter_vars[i]; + if (static_cast(i) == info_->rolling_axis) { + // All inner loops of the rolling axis has a loop carried dependency + // (i.e. 
each iteration calculation of the rolling axis depends on + // the calculation results of all the historical iterations of inner loops), + // so annotate the iteration type of the rolling axis as 'opaque', + // avoid the iterative range of its inner loop from being compressed + // during lowering phase. + IterVar new_iter_var = + IterVar(old_iter_var->dom, old_iter_var->var, IterVarType::kOpaque); + new_iter_vars.push_back(new_iter_var); + } else { + new_iter_vars.push_back(old_iter_var); + } + } + Map buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; + auto infered_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer); + + n->iter_vars = std::move(new_iter_vars); + RewriteAccessRegion(&n->reads, infered_access_regions[0]); + RewriteAccessRegion(&n->writes, infered_access_regions[1]); + } + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + BlockRealize stmt = Downcast(StmtExprMutator::VisitStmt_(realize)); + // Append block predicate to avoid recomputing elements. + if (rewrite_block_predicate_) { + rewrite_block_predicate_ = false; + PrimExpr condition = stmt->predicate; + for (size_t i = 0; i < info_->axis_iter_vars.size(); ++i) { + auto iter_var = info_->axis_iter_vars[i]; + if (iter_var && info_->axis_overlaps[i] > 0) { + Var var = iter_var.value(); + const Map dmap = {std::make_pair(var, arith::IntSet::Interval(0, 0))}; + auto iter_value = realize->iter_values[i]; + arith::Analyzer analyzer; + auto term_2 = analyzer.int_set(iter_value, dmap).min(); + condition = analyzer.Simplify( + And(condition, Or(LT(var, 1), GE(term_2, info_->axis_overlaps[i])))); + } + } + BlockRealizeNode* n = stmt.CopyOnWrite(); + n->predicate = condition; + } + return std::move(stmt); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore stmt = Downcast(StmtExprMutator::VisitStmt_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + BufferStoreNode* n = stmt.CopyOnWrite(); + RewriteBufferAccess(&n->buffer, &n->indices); + // Need to add predicate to the current block to avoid recomputing elements. + rewrite_block_predicate_ = true; + } + return std::move(stmt); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad stmt = Downcast(StmtExprMutator::VisitExpr_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + BufferLoadNode* n = stmt.CopyOnWrite(); + RewriteBufferAccess(&n->buffer, &n->indices); + } + return std::move(stmt); + } + + private: + const StmtSRef& scope_sref_; + RollingBufferInfo* info_; + bool rewrite_block_predicate_ = false; +}; + +} // namespace + +void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index) { + /*! + * Check + * - The block is not an output block. + * - The block has only RAW dependencies. + * - The block is tiled and there is access overlap between adjacent tiles. + * Mutate + * - Select the outermost rollable axis appeared in the block's loop nest + * as the 'rolling axis', trim the target buffer from the rolling axis. + * - Use modulo arithmetic to modify the target buffer's read and load + * indices to circularize the buffer along the rolling dimension. + * - Append block predicate to avoid recomputing overlapping elements. + */ + Map dom_map; + const BlockRealize& realize = GetBlockRealize(self, block_sref); + const Block& block = realize->block; + + // Step 1. Checking index, getting the target buffer region and the parent scope. 
+ const BufferRegion& buffer_region = + GetNthAccessBufferRegion(self, block, write_buffer_index, BufferIndexType::kWrite); + StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); + // Step 2. Check if the target block is not an output block and has only RAW dependencies. + CheckNotOutputBlock(self, block_sref, scope_root_sref); + RollingBufferDependencyError::Check(self, block_sref, scope_root_sref); + + // Step 3. Find the lca of the access location of the target buffer and relax the buffer + Array loop_srefs = GetLoops(block_sref); + Array consumers_sref = GetConsumers(self, block_sref); + consumers_sref.push_back(block_sref); + StmtSRef lca = GetSRefLowestCommonAncestor(consumers_sref); + if (!lca->StmtAs()) { + throw RollingBufferInsertionError(self->mod, buffer_region->buffer, block); + } + + for (auto it = loop_srefs.rbegin(); it != loop_srefs.rend(); ++it) { + auto stmt = *it; + // Stop at the lca of all the rolling_buffer access points; + if (stmt == lca) { + break; + } + For cur_loop = GetRef(stmt->StmtAs()); + Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent); + dom_map.Set(cur_loop->loop_var, arith::IntSet::FromRange(range)); + } + BufferRegion relaxed_region = GetRelaxedBufferRegion(realize, buffer_region, dom_map); + + // Step 4. Find a valid rolling axis and collect bound overlaps on the target buffer. + RollingBufferInfo info = RollingBufferInfoCollector::CheckAndGetRollingBufferInfo( + self->mod, block_sref, relaxed_region); + // Step 5. Mutate IR to apply rolling access pattern. + Stmt new_scope_root = RollingBufferRewriter::Rewrite(scope_root_sref, &info); + + // Step 6. Update schedule states + self->Replace(scope_root_sref, new_scope_root, info.block_reuse); + // Step 7. Regenerate block info from the root block, because `region_cover` for the target block + // and `stage_pipeline` for the root block are no longer satisfied after rolling buffer injection. 
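+  // On all but the first iteration the producer rewrites only the non-overlapping rows of the
+  // folded buffer under the new predicate, while its consumer still reads full windows through
+  // modulo indexing, so `region_cover` for the consumer can no longer be proven; this matches
+  // the Note in the Python docstring of `rolling_buffer`.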
+ self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, self->stmt2ref.at(new_scope_root.get()))); +} + +struct RollingBufferTraits : public UnpackedInstTraits { + static constexpr const char* kName = "RollingBuffer"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index) { + return sch->RollingBuffer(block, write_buffer_index.IntValue()); + } + + static String UnpackedAsPython(Array outputs, String block, Integer write_buffer_index) { + PythonAPICall py("rolling_buffer"); + py.Input("block", block); + py.Input("write_buffer_index", write_buffer_index); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(RollingBufferTraits); +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 280d0af92a8c..c50632f9026e 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -270,6 +270,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding") .set_body_method(&ScheduleNode::DecomposePadding); TVM_REGISTER_GLOBAL("tir.schedule.SchedulePadEinsum") .set_body_method(&ScheduleNode::PadEinsum); +/******** (FFI) Buffer transformation ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRollingBuffer") + .set_body_method(&ScheduleNode::RollingBuffer); /******** (FFI) Misc ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index b67b008feda4..36bfbc247183 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -561,6 +561,18 @@ void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const Array /*outputs=*/{})); } +/******** Schedule: Buffer transformation ********/ + +void TracedScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) { + ConcreteScheduleNode::RollingBuffer(block_rv, write_buffer_index); + static const InstructionKind& kind = InstructionKind::Get("RollingBuffer"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(write_buffer_index)}, + /*outputs=*/{})); +} + /******** Schedule: Misc ********/ void TracedScheduleNode::EnterPostproc() { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 016de60726b9..450c19326a46 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -113,6 +113,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { /******** Schedule: Padding ********/ BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) final; void PadEinsum(const BlockRV& block_rv, const Array& padding) final; + /******** Schedule: Buffer transformation ********/ + void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) final; /******** Schedule: Misc ********/ void EnterPostproc() final; }; diff --git a/tests/python/unittest/test_tir_schedule_rolling_buffer.py b/tests/python/unittest/test_tir_schedule_rolling_buffer.py new file mode 100644 index 000000000000..c55c41e451cc --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_rolling_buffer.py @@ -0,0 +1,573 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor 
license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import numpy as np +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip +import pytest + + +def check_rolling_buffer( + sch: tir.Schedule, origin: tir.PrimFunc, expected: tir.PrimFunc, check_run=False +): + scheduled = sch.mod["main"] + tvm.ir.assert_structural_equal(scheduled, expected) + verify_trace_roundtrip(sch, origin) + if check_run: + in_buffer = origin.buffer_map[origin.params[0]] + out_buffer = origin.buffer_map[origin.params[1]] + in_shape = [int(_) for _ in in_buffer.shape] + out_shape = [int(_) for _ in out_buffer.shape] + x = tvm.nd.array(np.random.uniform(0, 64, in_shape).astype(in_buffer.dtype)) + y0 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype)) + y1 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype)) + f_origin = tvm.build(origin) + f_scheduled = tvm.build(scheduled) + f_origin(x, y0) + f_scheduled(x, y1) + tvm.testing.assert_allclose(y0.numpy(), y1.numpy()) + + +def _tile_nd(s, tile, block_name): + outer_indices = [] + inner_indices = [] + block = s.get_block(block_name) + loops = s.get_loops(block) + for i, size in enumerate(tile): + outer, inner = s.split(loops[i], [None, size]) + outer_indices.append(outer) + inner_indices.append(inner) + + s.reorder(*outer_indices, *inner_indices) + return outer_indices, inner_indices + + +def test_1d_rolling_buffer(): + @T.prim_func + def before(A: T.Buffer[(4, 12), "int32"], C: T.Buffer[(4, 8), "int32"]): + B = T.alloc_buffer((4, 10), "int32") + for c in T.serial(4): + for i in T.serial(0, 10): + for k in T.serial(3): + with T.block("B"): + cc, vi, vk = T.axis.remap("SSR", [c, i, k]) + with T.init(): + B[cc, vi] = 0 + B[cc, vi] = B[cc, vi] + A[cc, vi + vk] + for i in T.serial(0, 8): + for k in T.serial(3): + with T.block("C"): + cc, vi, vk = T.axis.remap("SSR", [c, i, k]) + with T.init(): + C[cc, vi] = 0 + C[cc, vi] = C[cc, vi] + B[cc, vi + vk] + + @T.prim_func + def expected(A: T.Buffer[(4, 12), "int32"], C: T.Buffer[(4, 8), "int32"]): + B = T.alloc_buffer([4, 6], dtype="int32") + for c, i_0 in T.grid(4, 2): + for ax0, ax1 in T.grid(6, 3): + with T.block("B"): + T.where(i_0 < 1 or 2 <= ax0) + cc = T.axis.spatial(4, c) + vi = T.axis.opaque(10, i_0 * 4 + ax0) + vk = T.axis.reduce(3, ax1) + T.reads(A[cc, vi + vk]) + T.writes(B[cc, vi % 6]) + with T.init(): + B[cc, vi % 6] = 0 + B[cc, vi % 6] = B[cc, vi % 6] + A[cc, vi + vk] + for i_1, k in T.grid(4, 3): + with T.block("C"): + cc = T.axis.spatial(4, c) + vi = T.axis.opaque(8, i_0 * 4 + i_1) + vk = T.axis.reduce(3, k) + T.reads(B[cc, (vi + vk) % 6]) + T.writes(C[cc, vi]) + with T.init(): + C[cc, vi] = 0 + C[cc, vi] = C[cc, vi] + B[cc, (vi + vk) % 6] + + sch = tir.Schedule(before, debug_mask="all") + _, 
i, _ = sch.get_loops(sch.get_block("C")) + io, _ = sch.split(i, [2, 4]) + sch.compute_at(sch.get_block("B"), io) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, before, expected, check_run=True) + + +@T.prim_func +def cascade_2_max_pool2d(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 10, 10, 16], dtype="int8") + for i0, i1, i2, i3, i4, i5 in T.grid(1, 10, 10, 16, 3, 3): + with T.block("B"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + with T.init(): + B[ax0, ax1, ax2, ax3] = T.int8(-128) + B[ax0, ax1, ax2, ax3] = T.max(B[ax0, ax1, ax2, ax3], A[ax0, ax1 + rv0, ax2 + rv1, ax3]) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 8, 8, 16, 3, 3): + with T.block("C"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max(C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, ax2 + rv1, ax3]) + + +@T.prim_func +def cascade_3_max_pool2d_with_stride( + A: T.Buffer[(1, 24, 24, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"] +): + B_0 = T.alloc_buffer([1, 22, 22, 16], dtype="int8") + B_1 = T.alloc_buffer([1, 10, 10, 16], dtype="int8") + for i0, i1, i2, i3, i4, i5 in T.grid(1, 22, 22, 16, 3, 3): + with T.block("B_0"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + with T.init(): + B_0[ax0, ax1, ax2, ax3] = T.int8(-128) + B_0[ax0, ax1, ax2, ax3] = T.max( + B_0[ax0, ax1, ax2, ax3], A[ax0, ax1 + rv0, ax2 + rv1, ax3] + ) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 10, 10, 16, 3, 3): + with T.block("B_1"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + with T.init(): + B_1[ax0, ax1, ax2, ax3] = T.int8(-128) + B_1[ax0, ax1, ax2, ax3] = T.max( + B_1[ax0, ax1, ax2, ax3], B_0[ax0, ax1 * 2 + rv0, ax2 * 2 + rv1, ax3] + ) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 8, 8, 16, 3, 3): + with T.block("C"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B_1[ax0, ax1 + rv0, ax2 + rv1, ax3] + ) + + +def test_cascade_max_pool2d_w_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 10, 6, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 2, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(10, 6, 16, 3, 3): + with T.block("B"): + T.where(i2_0 < 1 or 2 <= ax1) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.spatial(10, ax0) + ax2_1 = T.axis.opaque(10, i2_0 * 4 + ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1, ax2_1 % 6, ax3_1]) + with T.init(): + B[ax0_1, ax1_1, ax2_1 % 6, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1, ax2_1 % 6, ax3_1] = T.max( + B[ax0_1, ax1_1, ax2_1 % 6, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 8, 4, 16, 3, 3): + with T.block("C"): + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.spatial(8, i1_0 * 8 + i1_1) + ax2 = T.axis.opaque(8, i2_0 * 4 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, ax1 + rv0, (ax2 + rv1) % 6, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, 
ax1 + rv0, (ax2 + rv1) % 6, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + oi, _ = _tile_nd(sch, [1, 8, 4, 16], "C") + sch.compute_at(sch.get_block("B"), oi[-1]) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_max_pool2d_h_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 6, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 1, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 10, 16, 3, 3): + with T.block("B"): + T.where(i1_0 < 1 or 2 <= ax0) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(10, i1_0 * 4 + ax0) + ax2_1 = T.axis.spatial(10, ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1 % 6, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 8, 16, 3, 3): + with T.block("C"): + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(8, i1_0 * 4 + i1_1) + ax2 = T.axis.spatial(8, i2_0 * 8 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + io, _ = _tile_nd(sch, [1, 4, 8, 16], "C") + sch.compute_at(sch.get_block("B"), io[-1]) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_max_pool2d_h_w_c_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 6, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 2): + for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 6, 8, 3, 3): + with T.block("B"): + T.where((i1_0 < 1 or 2 <= ax0) and (i2_0 < 1 or 2 <= ax1)) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(10, i1_0 * 4 + ax0) + ax2_1 = T.axis.spatial(10, i2_0 * 4 + ax1) + ax3_1 = T.axis.spatial(16, i3_0 * 8 + ax2) + rv0, rv1 = T.axis.remap("RR", [ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1 % 6, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 8, 3, 3): + with T.block("C"): + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(8, i1_0 * 4 + i1_1) + ax2 = T.axis.spatial(8, i2_0 * 4 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 8 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + io, _ = _tile_nd(sch, [1, 4, 4, 8], "C") + sch.compute_at(sch.get_block("B"), io[-1]) + 
sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_max_pool2d_non_perfect_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: + B = T.alloc_buffer([1, 8, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(8, 8, 16, 3, 3): + with T.block("B"): + T.where( + i1_0 * 6 + ax0 < 10 + and i2_0 * 6 + ax1 < 10 + and (i1_0 < 1 or 2 <= ax0) + and (i2_0 < 1 or 2 <= ax1) + ) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(10, i1_0 * 6 + ax0) + ax2_1 = T.axis.spatial(10, i2_0 * 6 + ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1 % 8, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 8, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 8, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1 % 8, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 6, 6, 16, 3, 3): + with T.block("C"): + T.where(i1_0 * 6 + i1_1 < 8 and i2_0 * 6 + i2_1 < 8) + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(8, i1_0 * 6 + i1_1) + ax2 = T.axis.spatial(8, i2_0 * 6 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, (ax1 + rv0) % 8, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, (ax1 + rv0) % 8, ax2 + rv1, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + io, _ = _tile_nd(sch, [1, 6, 6, 16], "C") + sch.compute_at(sch.get_block("B"), io[-1]) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_3_max_pool2d_with_stride(): + @T.prim_func + def expected(A: T.Buffer[(1, 24, 24, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: + B_0 = T.alloc_buffer([1, 13, 22, 16], dtype="int8") + B_1 = T.alloc_buffer([1, 6, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(13, 13, 16, 3, 3): + with T.block("B_0"): + T.where((i1_0 < 1 or 5 <= ax0) and (i2_0 < 1 or 5 <= ax1)) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(22, i1_0 * 8 + ax0) + ax2_1 = T.axis.spatial(22, i2_0 * 8 + ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1]) + with T.init(): + B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1] = T.int8(-128) + B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1] = T.max( + B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1], + A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1], + ) + for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 6, 16, 3, 3): + with T.block("B_1"): + T.where((i1_0 < 1 or 2 <= ax0) and (i2_0 < 1 or 2 <= ax1)) + ax0_2 = T.axis.spatial(1, 0) + ax1_2 = T.axis.opaque(10, i1_0 * 4 + ax0) + ax2_2 = T.axis.spatial(10, i2_0 * 4 + ax1) + ax3_2, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(B_0[ax0_2, (ax1_2 * 2 + rv0) % 13, ax2_2 * 2 + rv1, ax3_2]) + T.writes(B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2]) + with T.init(): + B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2] = T.int8(-128) + B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2] = T.max( + B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2], + B_0[ax0_2, (ax1_2 * 2 + rv0) % 13, ax2_2 * 2 + rv1, ax3_2], + ) + for i0_1, 
i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 3, 3): + with T.block("C"): + ax0_3 = T.axis.spatial(1, i0_0 + i0_1) + ax1_3 = T.axis.opaque(8, i1_0 * 4 + i1_1) + ax2_3 = T.axis.spatial(8, i2_0 * 4 + i2_1) + ax3_3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B_1[ax0_3, (ax1_3 + rv0) % 6, ax2_3 + rv1, ax3_3]) + T.writes(C[ax0_3, ax1_3, ax2_3, ax3_3]) + with T.init(): + C[ax0_3, ax1_3, ax2_3, ax3_3] = T.int8(-128) + C[ax0_3, ax1_3, ax2_3, ax3_3] = T.max( + C[ax0_3, ax1_3, ax2_3, ax3_3], + B_1[ax0_3, (ax1_3 + rv0) % 6, ax2_3 + rv1, ax3_3], + ) + + sch = tir.Schedule(cascade_3_max_pool2d_with_stride, debug_mask="all") + io, _ = _tile_nd(sch, [1, 4, 4, 16], "C") + sch.compute_at(sch.get_block("B_1"), io[-1]) + sch.compute_at(sch.get_block("B_0"), io[-1]) + sch.rolling_buffer(sch.get_block("B_0"), 0) + sch.rolling_buffer(sch.get_block("B_1"), 0) + check_rolling_buffer(sch, cascade_3_max_pool2d_with_stride, expected, check_run=True) + + +def test_upscale(): + @T.prim_func + def before(A: T.Buffer[(1, 16, 16, 16), "int8"], C: T.Buffer[(1, 24, 24, 16), "int8"]) -> None: + B = T.alloc_buffer([1, 14, 14, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 5, 5, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(5, 5, 16, 3, 3): + with T.block("B"): + T.where(i1_0 * 5 // 2 + ax0 < 14 and i2_0 * 5 // 2 + ax1 < 14) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.spatial(14, i1_0 * 5 // 2 + ax0) + ax2_1 = T.axis.spatial(14, i2_0 * 5 // 2 + ax1) + ax3_1 = T.axis.spatial(16, ax2) + rv0, rv1 = T.axis.remap("RR", [ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 5, 5, 16, 3, 3): + with T.block("C"): + T.where(i1_0 * 5 + i1_1 < 24 and i2_0 * 5 + i2_1 < 24) + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.spatial(24, i1_0 * 5 + i1_1) + ax2 = T.axis.spatial(24, i2_0 * 5 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, ax1 // 2 + rv0, ax2 // 2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, ax1 // 2 + rv0, ax2 // 2 + rv1, ax3] + ) + + @T.prim_func + def expected( + A: T.Buffer[(1, 16, 16, 16), "int8"], C: T.Buffer[(1, 24, 24, 16), "int8"] + ) -> None: + B = T.alloc_buffer([1, 5, 14, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 5, 5, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(5, 5, 16, 3, 3): + with T.block("B"): + T.where( + i1_0 * 5 // 2 + ax0 < 14 + and i2_0 * 5 // 2 + ax1 < 14 + and (i1_0 < 1 or 2 <= ax0) + and (i2_0 < 1 or 2 <= ax1) + ) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(14, i1_0 * 5 // 2 + ax0) + ax2_1 = T.axis.spatial(14, i2_0 * 5 // 2 + ax1) + ax3_1 = T.axis.spatial(16, ax2) + rv0, rv1 = T.axis.remap("RR", [ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1 % 5, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 5, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 5, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1 % 5, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 5, 5, 16, 3, 3): + with T.block("C"): + T.where(i1_0 * 5 + i1_1 < 24 and i2_0 * 5 + i2_1 < 24) + 
ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(24, i1_0 * 5 + i1_1) + ax2 = T.axis.spatial(24, i2_0 * 5 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, (ax1 // 2 + rv0) % 5, ax2 // 2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, (ax1 // 2 + rv0) % 5, ax2 // 2 + rv1, ax3] + ) + + sch = tir.Schedule(before, debug_mask="all") + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, before, expected, check_run=True) + + +def test_fail_rolling_buffer_multi_writers(): + @T.prim_func + def func_multi_writers( + A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"] + ): + B = T.alloc_buffer([1, 12, 12, 16], dtype="int8") + for i0, i1, i2, i3 in T.grid(1, 3, 3, 1): + for ax0, ax1, ax2 in T.grid(6, 6, 16): + with T.block("B_writer_0"): + ax0_1 = T.axis.spatial(1, i0) + ax1_1 = T.axis.spatial(12, i1 * 4 + ax0) + ax2_1 = T.axis.spatial(12, i2 * 4 + ax1) + ax3_1 = T.axis.spatial(16, ax2) + with T.init(): + B[ax0_1, ax1_1, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1, ax2_1, ax3_1] = A[ax0_1, ax1_1, ax2_1, ax3_1] + T.int8(1) + for ax0, ax1, ax2 in T.grid(6, 6, 16): + with T.block("B_writer_1"): + ax0_2 = T.axis.spatial(1, i0) + ax1_2 = T.axis.spatial(12, i1 * 4 + ax0) + ax2_2 = T.axis.spatial(12, i2 * 4 + ax1) + ax3_2 = T.axis.spatial(16, ax2) + with T.init(): + B[ax0_2, ax1_2, ax2_2, ax3_2] = T.int8(-128) + B[ax0_2, ax1_2, ax2_2, ax3_2] = B[ax0_2, ax1_2, ax2_2, ax3_2] + A[ + ax0_2, ax1_2, ax2_2, ax3_2 + ] * T.int8(2) + for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 4, 4, 16, 3, 3): + with T.block("C"): + ax0_3 = T.axis.spatial(1, i0 + ax0) + ax1_3 = T.axis.spatial(12, i1 * 4 + ax1) + ax2_3 = T.axis.spatial(12, i2 * 4 + ax2) + ax3_3 = T.axis.spatial(16, i3 * 16 + ax3) + rv0, rv1 = T.axis.remap("RR", [ax4, ax5]) + with T.init(): + C[ax0_3, ax1_3, ax2_3, ax3_3] = T.int8(-128) + C[ax0_3, ax1_3, ax2_3, ax3_3] = T.max( + C[ax0_3, ax1_3, ax2_3, ax3_3], B[ax0_3, ax1_3 + rv0, ax2_3 + rv1, ax3_3] + ) + + sch = tir.Schedule(func_multi_writers, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.rolling_buffer(sch.get_block("B_writer_0"), 0) + + +def test_fail_rolling_buffer_not_match(): + @T.prim_func + def func_non_overlap( + A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"] + ): + B = T.alloc_buffer([1, 12, 12, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 3, 3, 1): + for ax0, ax1, ax2 in T.grid(4, 4, 16): + with T.block("B"): + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.spatial(12, i1_0 * 4 + ax0) + ax2_1 = T.axis.spatial(12, i2_0 * 4 + ax1) + ax3 = T.axis.spatial(16, ax2) + T.reads(A[ax0_1, ax1_1, ax2_1, ax3]) + T.writes(B[ax0_1, ax1_1, ax2_1, ax3]) + with T.init(): + B[ax0_1, ax1_1, ax2_1, ax3] = T.int8(-128) + B[ax0_1, ax1_1, ax2_1, ax3] = A[ax0_1, ax1_1, ax2_1, ax3] + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 1, 1): + with T.block("C"): + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.spatial(12, i1_0 * 4 + i1_1) + ax2 = T.axis.spatial(12, i2_0 * 4 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, ax1 + rv0, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, ax2 + rv1, ax3] + ) + + sch = 
tir.Schedule(func_non_overlap, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.rolling_buffer(sch.get_block("B"), 0) + + +def test_fail_rolling_buffer_injection_invalid(): + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + # Block B is not compute_at to Block C, so rolling_buffer injection is invalid. + _, _ = _tile_nd(sch, [1, 4, 8, 16], "C") + _, _ = _tile_nd(sch, [1, 4, 8, 16], "B") + with pytest.raises(tvm.tir.ScheduleError): + sch.rolling_buffer(sch.get_block("B"), 0) + + +if __name__ == "__main__": + tvm.testing.main()
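# A minimal usage sketch (illustrative only; it mirrors test_cascade_max_pool2d_h_tiled above):
# tile the consumer, compute_at the producer, then roll the producer's write buffer. With a 3x3
# window and a tile of 4 output rows, each iteration reads a 6-row window while advancing by
# 4 rows, so the overlap is 2 and B is folded from 10 down to 6 rows.
def _rolling_buffer_usage_sketch():
    sch = tir.Schedule(cascade_2_max_pool2d)
    outer, _ = _tile_nd(sch, [1, 4, 8, 16], "C")  # tile the consumer block "C"
    sch.compute_at(sch.get_block("B"), outer[-1])  # cascade the producer "B" under the tile loops
    # Fold the intermediate buffer B along the rolling (height) axis: B becomes [1, 6, 10, 16],
    # is read/written with modulo indexing, and gains a predicate skipping the overlapping rows.
    sch.rolling_buffer(sch.get_block("B"), write_buffer_index=0)
    print(sch.mod["main"].script())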