From 1b669df373681d5a5c023e2ddab2c1e32ce8b2d7 Mon Sep 17 00:00:00 2001 From: LiangW <732811423@qq.com> Date: Tue, 11 Oct 2022 06:53:49 +0000 Subject: [PATCH 1/3] [TIR][Primitive] Support rolling_buffer schedule primitive in TensorIR --- include/tvm/tir/schedule/schedule.h | 3 + python/tvm/tir/schedule/schedule.py | 105 ++++ src/tir/schedule/concrete_schedule.cc | 9 + src/tir/schedule/concrete_schedule.h | 2 + src/tir/schedule/primitive.h | 2 + src/tir/schedule/primitive/rolling_buffer.cc | 443 +++++++++++++++ src/tir/schedule/schedule.cc | 3 + src/tir/schedule/traced_schedule.cc | 12 + src/tir/schedule/traced_schedule.h | 2 + .../test_tir_schedule_rolling_buffer.py | 534 ++++++++++++++++++ 10 files changed, 1115 insertions(+) create mode 100644 src/tir/schedule/primitive/rolling_buffer.cc create mode 100644 tests/python/unittest/test_tir_schedule_rolling_buffer.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9ec2841ebd5e..b39491eacd8a 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -681,6 +681,9 @@ class ScheduleNode : public runtime::Object { */ virtual void PadEinsum(const BlockRV& block_rv, const Array& padding) = 0; + /******** Schedule: Buffer transformation ********/ + virtual void RollingBuffer(const BlockRV& block_rv, int buffer_index) = 0; + /******** Schedule: Misc ********/ /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ virtual void EnterPostproc() = 0; diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 4814271f4023..ad32cb46a9cb 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -3042,6 +3042,111 @@ def after_pad_einsum( self, block, padding ) + ######## Schedule: Buffer transformation ######## + + @type_checked + def rolling_buffer( + self, + block: Union[BlockRV, str], + buffer_index: int, + ) -> None: + """Compute the target buffer via rolling buffering, select the outermost rollable + axis with a positive bound overlap that appears in the block's ancestor loops + as `rolling axis`. It requires: + + 1) The buffer to be an intermediate buffer defined via `alloc_buffer`. + + 2) The LCA of the producer and consumer of the buffer is a for loop, typically, + the producer and consumer of the buffer are cascaded through compute_at. + + 3) The access region of the buffer has at least one dimension that contains + a positive bound overlap. + + Parameters + ---------- + block : Union[BlockRV, str] + The producer block of the buffer. + buffer_index : int + The index of the buffer in block's write region. + + Examples + -------- + + Before rolling_buffer, in TensorIR, the IR is: + + .. 
code-block:: python + + @T.prim_func + def before_rolling_buffer( + A: T.Buffer[(12, 12), "int8"], C: T.Buffer[(8, 8), "int8"] + ) -> None: + # body + # with T.block("root") + B = T.alloc_buffer([10, 10], dtype="int8") + for i0, i1 in T.grid(2, 2): + for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3): + with T.block("B"): + ax0_1 = T.axis.spatial(10, i0 * 4 + ax0) + ax1_1 = T.axis.spatial(10, i1 * 4 + ax1) + rv0, rv1 = T.axis.remap("RR", [ax2, ax3]) + B[ax0_1, ax1_1] = T.max( + B[ax0_1, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1] + ) + for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3): + with T.block("C"): + ax0_1 = T.axis.spatial(8, i0 * 4 + ax0) + ax1_1 = T.axis.spatial(8, i1 * 4 + ax1) + rv0, rv1 = T.axis.remap("RR", [ax2, ax3]) + C[ax0_1, ax1_1] = T.max( + C[ax0_1, ax1_1], B[ax0_1 + rv0, ax1_1 + rv1] + ) + + Create the schedule and do rolling_buffer: + + .. code-block:: python + + sch = tir.Schedule(before_rolling_buffer) + sch.rolling_buffer(sch.get_block("B"), buffer_index=0) + print(sch.mod["main"].script()) + + After applying rolling_buffer, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_rolling_buffer( + A: T.Buffer[(12, 12), "int8"], + C: T.Buffer[(8, 8), "int8"] + ) -> None: + # body + # with T.block("root") + B = T.alloc_buffer([6, 10], dtype="int8") + for i0, i1 in T.grid(2, 2): + for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3): + with T.block("B"): + T.where((i0 < 1 or 2 <= ax0) and (i1 < 1 or 2 <= ax1)) + ax0_1 = T.axis.spatial(10, i0 * 4 + ax0) + ax1_1 = T.axis.spatial(10, i1 * 4 + ax1) + rv0, rv1 = T.axis.remap("RR", [ax2, ax3]) + B[ax0_1 % 6, ax1_1] = T.max( + B[ax0_1 % 6, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1] + ) + for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3): + with T.block("C"): + ax0_1 = T.axis.spatial(8, i0 * 4 + ax0) + ax1_1 = T.axis.spatial(8, i1 * 4 + ax1) + rv0, rv1 = T.axis.remap("RR", [ax2, ax3]) + C[ax0_1, ax1_1] = T.max( + C[ax0_1, ax1_1], B[ax0_1 % 6 + rv0, ax1_1 + rv1] + ) + + Note + ---- + The region_cover property of the consumer block of the target buffer will become false. 
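        The cascaded form shown in `before_rolling_buffer` is typically constructed with
        `split`/`reorder`/`compute_at` before calling `rolling_buffer`. A minimal sketch,
        illustrative only: `pooling_mod` is a hypothetical module whose consumer block "C"
        has two spatial loops.

        .. code-block:: python

            sch = tir.Schedule(pooling_mod)
            # Tile the consumer block, then cascade the producer into the tile loop.
            i, j = sch.get_loops(sch.get_block("C"))[:2]
            i_o, i_i = sch.split(i, [None, 4])
            j_o, j_i = sch.split(j, [None, 4])
            sch.reorder(i_o, j_o, i_i, j_i)
            sch.compute_at(sch.get_block("B"), j_o)
            sch.rolling_buffer(sch.get_block("B"), buffer_index=0)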
+ """ + block = self._normalize_block_arg(block) + return _ffi_api.ScheduleRollingBuffer(self, block, buffer_index) # type: ignore # pylint: disable=no-member + ########## Schedule: Misc ########## @type_checked diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 3960087cf745..ae500fa563fb 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -638,6 +638,15 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { this->state_->DebugVerify(); } +/******** Schedule: Buffer Transformation ********/ + +void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int buffer_index) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::RollingBuffer(state_, this->GetSRef(block_rv), buffer_index); + TVM_TIR_SCHEDULE_END("rolling-buffer", this->error_render_level_); + this->state_->DebugVerify(); +} + /******** Schedule: Block Annotation ********/ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index bfdc082d4ce6..e3eddd7f8395 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -153,6 +153,8 @@ class ConcreteScheduleNode : public ScheduleNode { const Array& axis_separators) override; /******** Schedule: Padding decomposition ********/ BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override; + /******** Schedule: Buffer transformation ********/ + void RollingBuffer(const BlockRV& block_rv, int buffer_index) override; /******** Schedule: Misc ********/ void EnterPostproc() override {} diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 88331fb5b9d3..d5ce299ff036 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -395,6 +395,8 @@ TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sr * \return The sref of the rfactor block */ TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis); +/******** Schedule: Buffer transformation ********/ +TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int buffer_index); /******** Schedule: Block annotation ********/ /*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, offset) */ using StorageAlignTuple = Array; diff --git a/src/tir/schedule/primitive/rolling_buffer.cc b/src/tir/schedule/primitive/rolling_buffer.cc new file mode 100644 index 000000000000..8c93de75f581 --- /dev/null +++ b/src/tir/schedule/primitive/rolling_buffer.cc @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + int rolling_extent; + std::vector axis_overlaps; + std::vector> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. */ + Map block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map& dom_map) { + Array relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferMatchError : public ScheduleError { + public: + RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + : mod_(mod), block_(block), buffer_region_(buffer_region) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" + "matching the rolling pattern such as: hh.outer * stride + hh.inner"; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_region_->buffer->name << " with region " + << buffer_region_->region + << " should have at least one dimension range that matches a rolling pattern " + "such as hh.outer * stride + hh.inner. "; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion buffer_region_; +}; + +class RollingBufferInsertionError : public ScheduleError { + public: + RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + : mod_(mod), buffer_(std::move(buffer)), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " + "location of the target buffer is not a for loop. "; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " + << "the lca of the access location of the target buffer " << buffer_->name + << " is a for loop. 
"; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector> bound_iter_vars; + std::vector bound_overlaps; + auto stride = 0; + auto divisor = 1; + Optional iter_var; + for (auto bound : region) { + divisor = 1; + if (auto floor_div = bound->min.as()) { + // Handle the case of fractional strides + // They take this form: floordiv(hh.outer, 2) + // Strip the floordiv and keep track of the divisor + divisor = Downcast(floor_div->b)->value; + bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); + } + if (bound->min.as()) { + // If the bound is an int, we can't roll over it + iter_var = NullOpt; + } else if (auto var = bound->min.as()) { + // If the bound is just a Var, that implies the stride is 1 + iter_var = GetRef(var); + stride = 1; + } else { + // Otherwise, it's the iter var multiplied by the stride + // If not we're in unknown behaviour + if (auto mul = bound->min.as()) { + if (mul->a->IsInstance() && mul->b->IsInstance()) { + iter_var = Downcast(mul->a); + stride = Downcast(mul->b)->value; + } else { + return false; + } + } else { + return false; + } + } + stride = std::ceil(static_cast(stride) / divisor); + auto bound_overlap = 0; + if (iter_var.defined()) { + auto extent = bound->extent.as(); + ICHECK(extent); + bound_overlap = extent->value - stride; + // Since Pass CompactBufferAllocation will be responsible for compacting the buffer + // allocation region, there is no need to roll over the axis where the overlap is not + // positive, so reset iter_var to NullOpt. 
+ if (bound_overlap <= 0) { + iter_var = NullOpt; + } + } + bound_iter_vars.push_back(iter_var); + bound_overlaps.push_back(bound_overlap); + } + Array loop_srefs = GetLoops(block_sref); + // Pick the outermost iter_var that's mentioned in the bounds + // to be the rolling axis + Optional roll_iter_var; + int roll_axis; + for (const tir::StmtSRef& loop_sref : loop_srefs) { + auto loop_var = loop_sref->StmtAs()->loop_var; + + auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional var) { + return var && (var.get() == loop_var.get()); + })}; + if (it != bound_iter_vars.end()) { + auto i = std::distance(bound_iter_vars.begin(), it); + roll_iter_var = loop_var; + roll_axis = i; + break; + } + } + + if (!roll_iter_var.defined()) { + return false; + } + Array new_shape = buffer->shape; + new_shape.Set(roll_axis, region[roll_axis]->extent); + Buffer new_buffer = buffer; + new_buffer.CopyOnWrite()->shape = new_shape; + + info_.old_buffer = buffer; + info_.new_buffer = new_buffer; + info_.rolling_axis = roll_axis; + info_.rolling_extent = static_cast(Downcast(region[roll_axis]->extent)->value); + info_.axis_overlaps = bound_overlaps; + info_.axis_iter_vars = bound_iter_vars; + + return true; + } + + RollingBufferInfo info_; +}; + +class RollingBufferRewriter : public StmtExprMutator { + public: + static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) { + RollingBufferRewriter rewriter(scope_sref, info); + return rewriter(GetRef(scope_sref->stmt)); + } + + private: + explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info) + : scope_sref_(scope_sref), info_(info) {} + + void RewriteAccessRegion(Array* old_access_regions, + const Array& infered_access_regions) { + auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { + if (buffer_region->buffer.same_as(info_->old_buffer)) { + ICHECK(infered_access_regions.size() == 1); + return infered_access_regions[0]; + } + return buffer_region; + }; + (*old_access_regions).MutateByApply(fmutate); + } + + Stmt VisitStmt_(const BlockNode* block) final { + Block old_stmt = GetRef(block); + Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); + if (block == scope_sref_->stmt) { + ObjectPtr n = make_object(*stmt.as()); + + Array new_alloc_buffers; + for (const Buffer& buffer : stmt->alloc_buffers) { + if (buffer != info_->old_buffer) { + new_alloc_buffers.push_back(buffer); + } else { + new_alloc_buffers.push_back(info_->new_buffer); + } + } + n->alloc_buffers = std::move(new_alloc_buffers); + stmt = Block(n); + } else { + Array new_iter_bindings; + for (size_t i = 0; i < stmt->iter_vars.size(); ++i) { + auto old_iter_var = stmt->iter_vars[i]; + if (static_cast(i) == info_->rolling_axis) { + // All inner loops of the rolling axis has a loop carried dependency + // (i.e. each iteration calculation of the rolling axis depends on + // the calculation results of all the historical iterations of inner loops), + // so annotate the iteration type of the rolling axis as 'opaque', + // avoid the iterative range of its inner loop from being compressed + // during lowering phase. 
+ IterVar new_iter_var = + IterVar(old_iter_var->dom, old_iter_var->var, IterVarType::kOpaque); + new_iter_bindings.push_back(new_iter_var); + } else { + new_iter_bindings.push_back(old_iter_var); + } + } + Map buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; + auto infered_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer); + + BlockNode* n = stmt.CopyOnWrite(); + n->iter_vars = std::move(new_iter_bindings); + RewriteAccessRegion(&n->reads, infered_access_regions[0]); + RewriteAccessRegion(&n->writes, infered_access_regions[1]); + } + info_->block_reuse.Set(old_stmt, stmt); + return std::move(stmt); + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + BlockRealize stmt = Downcast(StmtExprMutator::VisitStmt_(realize)); + // Append block predicate to avoid recomputing elements. + if (rewrite_block_predicate_) { + rewrite_block_predicate_ = false; + PrimExpr condition = stmt->predicate; + for (size_t i = 0; i < info_->axis_iter_vars.size(); ++i) { + auto iter_var = info_->axis_iter_vars[i]; + if (iter_var && info_->axis_overlaps[i] > 0) { + Var var = iter_var.value(); + const Map dmap = {std::make_pair(var, arith::IntSet::Interval(0, 0))}; + auto iter_value = realize->iter_values[i]; + arith::Analyzer analyzer; + auto term_2 = analyzer.int_set(iter_value, dmap).min(); + condition = analyzer.Simplify( + And(condition, Or(LT(var, 1), GE(term_2, info_->axis_overlaps[i])))); + } + } + BlockRealizeNode* n = stmt.CopyOnWrite(); + n->predicate = condition; + } + return std::move(stmt); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore stmt = Downcast(StmtExprMutator::VisitStmt_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + Array new_indices; + new_indices.reserve(stmt->indices.size()); + // First modify the access indices to use modulo arithmetic + // for the rolling axis + for (size_t i = 0; i < stmt->indices.size(); ++i) { + auto index = stmt->indices[i]; + if (static_cast(i) == info_->rolling_axis) { + new_indices.push_back(FloorMod(index, info_->rolling_extent)); + } else { + new_indices.push_back(index); + } + } + BufferStoreNode* n = stmt.CopyOnWrite(); + // Replace the stored buffer with the new buffer. + n->buffer = info_->new_buffer; + n->indices = std::move(new_indices); + // Need to add predicate to the current block to avoid recomputing elements. + rewrite_block_predicate_ = true; + } + return std::move(stmt); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad stmt = Downcast(StmtExprMutator::VisitExpr_(op)); + if (stmt->buffer.same_as(info_->old_buffer)) { + Array new_indices; + new_indices.reserve(stmt->indices.size()); + for (size_t i{0}; i < stmt->indices.size(); ++i) { + auto index = stmt->indices[i]; + if (static_cast(i) == info_->rolling_axis) { + new_indices.push_back(FloorMod(index, info_->rolling_extent)); + } else { + new_indices.push_back(index); + } + } + BufferLoadNode* n = stmt.CopyOnWrite(); + // Replace the loaded buffer with the new buffer. + n->buffer = info_->new_buffer; + n->indices = std::move(new_indices); + } + return std::move(stmt); + } + + private: + const StmtSRef& scope_sref_; + RollingBufferInfo* info_; + bool rewrite_block_predicate_ = false; +}; + +} // namespace + +void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index) { + /*! + * Check + * - The block is not an output block. + * - The block is tiled and there is access overlap between adjacent tiles. 
+ * Mutate + * - Select the outermost rollable axis appeared in the block's loop nest + * as the 'rolling axis', trim the target buffer from the rolling axis. + * - Use modulo arithmetic to modify the target buffer's read and load + * indices to circularize the buffer along the rolling dimension. + * - Append block predicate to avoid recomputing overlapping elements. + */ + Map dom_map; + const BlockRealize& realize = GetBlockRealize(self, block_sref); + const Block& block = realize->block; + + // Step 1. Checking index, getting the target buffer region and the parent scope. + const BufferRegion& buffer_region = + GetNthAccessBufferRegion(self, block, write_buffer_index, BufferIndexType::kWrite); + StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); + // Step 2. Check the target block is not an output block. + CheckNotOutputBlock(self, block_sref, scope_root_sref); + + // Step 3. Find the lca of the access location of the target buffer and relax the buffer + Array loop_srefs = GetLoops(block_sref); + Array consumers_sref = GetConsumers(self, block_sref); + consumers_sref.push_back(block_sref); + StmtSRef lca = GetSRefLowestCommonAncestor(consumers_sref); + if (!lca->StmtAs()) { + throw RollingBufferInsertionError(self->mod, buffer_region->buffer, block); + } + + for (auto it = loop_srefs.rbegin(); it != loop_srefs.rend(); ++it) { + auto stmt = *it; + // Stop at the lca of all the rolling_buffer access points; + if (stmt == lca) { + break; + } + For cur_loop = GetRef(stmt->StmtAs()); + Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent); + dom_map.Set(cur_loop->loop_var, arith::IntSet::FromRange(range)); + } + BufferRegion relaxed_region = GetRelaxedBufferRegion(realize, buffer_region, dom_map); + + // Step 4. Find an valid rolling axis and collect bound overlaps on the target buffer. + RollingBufferInfo info = RollingBufferInfoCollector::CheckAndGetRollingBufferInfo( + self->mod, block_sref, relaxed_region); + // Step 5. Mutate IR to apply rolling access pattern. + Stmt new_scope_root = RollingBufferRewriter::Rewrite(scope_root_sref, &info); + + // Step 6. Update schedule states + self->Replace(scope_root_sref, new_scope_root, info.block_reuse); + // Step 7. Regenerate block info from the root block, because `region_cover` for the target block + // and `stage_pipeline` for the root block are no longer satisfied after rolling buffer injection. 
+ self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, self->stmt2ref.at(new_scope_root.get()))); +} + +struct RollingBufferTraits : public UnpackedInstTraits { + static constexpr const char* kName = "RollingBuffer"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index) { + return sch->RollingBuffer(block, buffer_index.IntValue()); + } + + static String UnpackedAsPython(Array outputs, String block, Integer buffer_index) { + PythonAPICall py("rolling_buffer"); + py.Input("block", block); + py.Input("buffer_index", buffer_index); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(RollingBufferTraits); +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 280d0af92a8c..c50632f9026e 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -270,6 +270,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposePadding") .set_body_method(&ScheduleNode::DecomposePadding); TVM_REGISTER_GLOBAL("tir.schedule.SchedulePadEinsum") .set_body_method(&ScheduleNode::PadEinsum); +/******** (FFI) Buffer transformation ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRollingBuffer") + .set_body_method(&ScheduleNode::RollingBuffer); /******** (FFI) Misc ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index b67b008feda4..3fa997376ac8 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -561,6 +561,18 @@ void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const Array /*outputs=*/{})); } +/******** Schedule: Buffer transformation ********/ + +void TracedScheduleNode::RollingBuffer(const BlockRV& block_rv, int buffer_index) { + ConcreteScheduleNode::RollingBuffer(block_rv, buffer_index); + static const InstructionKind& kind = InstructionKind::Get("RollingBuffer"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index)}, + /*outputs=*/{})); +} + /******** Schedule: Misc ********/ void TracedScheduleNode::EnterPostproc() { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 016de60726b9..66dbd639cb01 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -113,6 +113,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { /******** Schedule: Padding ********/ BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) final; void PadEinsum(const BlockRV& block_rv, const Array& padding) final; + /******** Schedule: Buffer transformation ********/ + void RollingBuffer(const BlockRV& block_rv, int buffer_index) final; /******** Schedule: Misc ********/ void EnterPostproc() final; }; diff --git a/tests/python/unittest/test_tir_schedule_rolling_buffer.py b/tests/python/unittest/test_tir_schedule_rolling_buffer.py new file mode 100644 index 000000000000..ea87ef594023 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_rolling_buffer.py @@ -0,0 +1,534 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import numpy as np +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip +import pytest + + +def check_rolling_buffer( + sch: tir.Schedule, origin: tir.PrimFunc, expected: tir.PrimFunc, check_run=False +): + scheduled = sch.mod["main"] + tvm.ir.assert_structural_equal(scheduled, expected) + verify_trace_roundtrip(sch, origin) + if check_run: + in_buffer = origin.buffer_map[origin.params[0]] + out_buffer = origin.buffer_map[origin.params[1]] + in_shape = [int(_) for _ in in_buffer.shape] + out_shape = [int(_) for _ in out_buffer.shape] + x = tvm.nd.array(np.random.uniform(0, 64, in_shape).astype(in_buffer.dtype)) + y0 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype)) + y1 = tvm.nd.array(np.zeros(out_shape).astype(out_buffer.dtype)) + f_origin = tvm.build(origin) + f_scheduled = tvm.build(scheduled) + f_origin(x, y0) + f_scheduled(x, y1) + tvm.testing.assert_allclose(y0.numpy(), y1.numpy()) + + +def _tile_nd(s, tile, block_name): + outer_indices = [] + inner_indices = [] + block = s.get_block(block_name) + loops = s.get_loops(block) + for i, size in enumerate(tile): + outer, inner = s.split(loops[i], [None, size]) + outer_indices.append(outer) + inner_indices.append(inner) + + s.reorder(*outer_indices, *inner_indices) + return outer_indices, inner_indices + + +def test_1d_rolling_buffer(): + @T.prim_func + def before(A: T.Buffer[(4, 12), "int32"], C: T.Buffer[(4, 8), "int32"]): + B = T.alloc_buffer((4, 10), "int32") + for c in T.serial(4): + for i in T.serial(0, 10): + for k in T.serial(3): + with T.block("B"): + cc, vi, vk = T.axis.remap("SSR", [c, i, k]) + with T.init(): + B[cc, vi] = 0 + B[cc, vi] = B[cc, vi] + A[cc, vi + vk] + for i in T.serial(0, 8): + for k in T.serial(3): + with T.block("C"): + cc, vi, vk = T.axis.remap("SSR", [c, i, k]) + with T.init(): + C[cc, vi] = 0 + C[cc, vi] = C[cc, vi] + B[cc, vi + vk] + + @T.prim_func + def expected(A: T.Buffer[(4, 12), "int32"], C: T.Buffer[(4, 8), "int32"]): + B = T.alloc_buffer([4, 6], dtype="int32") + for c, i_0 in T.grid(4, 2): + for ax0, ax1 in T.grid(6, 3): + with T.block("B"): + T.where(i_0 < 1 or 2 <= ax0) + cc = T.axis.spatial(4, c) + vi = T.axis.opaque(10, i_0 * 4 + ax0) + vk = T.axis.reduce(3, ax1) + T.reads(A[cc, vi + vk]) + T.writes(B[cc, vi % 6]) + with T.init(): + B[cc, vi % 6] = 0 + B[cc, vi % 6] = B[cc, vi % 6] + A[cc, vi + vk] + for i_1, k in T.grid(4, 3): + with T.block("C"): + cc = T.axis.spatial(4, c) + vi = T.axis.opaque(8, i_0 * 4 + i_1) + vk = T.axis.reduce(3, k) + T.reads(B[cc, (vi + vk) % 6]) + T.writes(C[cc, vi]) + with T.init(): + C[cc, vi] = 0 + C[cc, vi] = C[cc, vi] + B[cc, (vi + vk) % 6] + + sch = tir.Schedule(before, debug_mask="all") + _, i, _ = 
sch.get_loops(sch.get_block("C")) + io, _ = sch.split(i, [2, 4]) + sch.compute_at(sch.get_block("B"), io) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, before, expected, check_run=True) + + +@T.prim_func +def cascade_2_max_pool2d(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 10, 10, 16], dtype="int8") + for i0, i1, i2, i3, i4, i5 in T.grid(1, 10, 10, 16, 3, 3): + with T.block("B"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(A[ax0, ax1 + rv0, ax2 + rv1, ax3]) + T.writes(B[ax0, ax1, ax2, ax3]) + with T.init(): + B[ax0, ax1, ax2, ax3] = T.int8(-128) + B[ax0, ax1, ax2, ax3] = T.max(B[ax0, ax1, ax2, ax3], A[ax0, ax1 + rv0, ax2 + rv1, ax3]) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 8, 8, 16, 3, 3): + with T.block("C"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(B[ax0, ax1 + rv0, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max(C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, ax2 + rv1, ax3]) + + +@T.prim_func +def cascade_3_max_pool2d_with_stride( + A: T.Buffer[(1, 24, 24, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"] +): + B_0 = T.alloc_buffer([1, 22, 22, 16], dtype="int8") + B_1 = T.alloc_buffer([1, 10, 10, 16], dtype="int8") + for i0, i1, i2, i3, i4, i5 in T.grid(1, 22, 22, 16, 3, 3): + with T.block("B_0"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(A[ax0, ax1 + rv0, ax2 + rv1, ax3]) + T.writes(B_0[ax0, ax1, ax2, ax3]) + with T.init(): + B_0[ax0, ax1, ax2, ax3] = T.int8(-128) + B_0[ax0, ax1, ax2, ax3] = T.max( + B_0[ax0, ax1, ax2, ax3], A[ax0, ax1 + rv0, ax2 + rv1, ax3] + ) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 10, 10, 16, 3, 3): + with T.block("B_1"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(B_0[ax0, ax1 * 2 + rv0, ax2 * 2 + rv1, ax3]) + T.writes(B_1[ax0, ax1, ax2, ax3]) + with T.init(): + B_1[ax0, ax1, ax2, ax3] = T.int8(-128) + B_1[ax0, ax1, ax2, ax3] = T.max( + B_1[ax0, ax1, ax2, ax3], B_0[ax0, ax1 * 2 + rv0, ax2 * 2 + rv1, ax3] + ) + for i0, i1, i2, i3, i4, i5 in T.grid(1, 8, 8, 16, 3, 3): + with T.block("C"): + ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(B_1[ax0, ax1 + rv0, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B_1[ax0, ax1 + rv0, ax2 + rv1, ax3] + ) + + +def test_cascade_max_pool2d_w_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 10, 6, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 2, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(10, 6, 16, 3, 3): + with T.block("B"): + T.where(i2_0 < 1 or 2 <= ax1) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.spatial(10, ax0) + ax2_1 = T.axis.opaque(10, i2_0 * 4 + ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1, ax2_1 % 6, ax3_1]) + with T.init(): + B[ax0_1, ax1_1, ax2_1 % 6, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1, ax2_1 % 6, ax3_1] = T.max( + B[ax0_1, ax1_1, ax2_1 % 6, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 8, 4, 16, 3, 3): + with T.block("C"): + ax0 = 
T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.spatial(8, i1_0 * 8 + i1_1) + ax2 = T.axis.opaque(8, i2_0 * 4 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, ax1 + rv0, (ax2 + rv1) % 6, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, (ax2 + rv1) % 6, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + oi, _ = _tile_nd(sch, [1, 8, 4, 16], "C") + sch.compute_at(sch.get_block("B"), oi[-1]) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_max_pool2d_h_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 6, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 1, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 10, 16, 3, 3): + with T.block("B"): + T.where(i1_0 < 1 or 2 <= ax0) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(10, i1_0 * 4 + ax0) + ax2_1 = T.axis.spatial(10, ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1 % 6, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 8, 16, 3, 3): + with T.block("C"): + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(8, i1_0 * 4 + i1_1) + ax2 = T.axis.spatial(8, i2_0 * 8 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + io, _ = _tile_nd(sch, [1, 4, 8, 16], "C") + sch.compute_at(sch.get_block("B"), io[-1]) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_max_pool2d_h_w_c_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]): + B = T.alloc_buffer([1, 6, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 2): + for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 6, 8, 3, 3): + with T.block("B"): + T.where((i1_0 < 1 or 2 <= ax0) and (i2_0 < 1 or 2 <= ax1)) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(10, i1_0 * 4 + ax0) + ax2_1 = T.axis.spatial(10, i2_0 * 4 + ax1) + ax3_1 = T.axis.spatial(16, i3_0 * 8 + ax2) + rv0, rv1 = T.axis.remap("RR", [ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1 % 6, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1 % 6, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 8, 3, 3): + with T.block("C"): + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(8, i1_0 * 4 + i1_1) + ax2 = T.axis.spatial(8, i2_0 * 4 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 8 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + 
T.reads(B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, (ax1 + rv0) % 6, ax2 + rv1, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + io, _ = _tile_nd(sch, [1, 4, 4, 8], "C") + sch.compute_at(sch.get_block("B"), io[-1]) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_max_pool2d_non_perfect_tiled(): + @T.prim_func + def expected(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: + B = T.alloc_buffer([1, 8, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(8, 8, 16, 3, 3): + with T.block("B"): + T.where( + i1_0 * 6 + ax0 < 10 + and i2_0 * 6 + ax1 < 10 + and (i1_0 < 1 or 2 <= ax0) + and (i2_0 < 1 or 2 <= ax1) + ) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(10, i1_0 * 6 + ax0) + ax2_1 = T.axis.spatial(10, i2_0 * 6 + ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1 % 8, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 8, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 8, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1 % 8, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 6, 6, 16, 3, 3): + with T.block("C"): + T.where(i1_0 * 6 + i1_1 < 8 and i2_0 * 6 + i2_1 < 8) + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(8, i1_0 * 6 + i1_1) + ax2 = T.axis.spatial(8, i2_0 * 6 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, (ax1 + rv0) % 8, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, (ax1 + rv0) % 8, ax2 + rv1, ax3] + ) + + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + io, _ = _tile_nd(sch, [1, 6, 6, 16], "C") + sch.compute_at(sch.get_block("B"), io[-1]) + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, cascade_2_max_pool2d, expected, check_run=True) + + +def test_cascade_3_max_pool2d_with_stride(): + @T.prim_func + def expected(A: T.Buffer[(1, 24, 24, 16), "int8"], C: T.Buffer[(1, 8, 8, 16), "int8"]) -> None: + B_0 = T.alloc_buffer([1, 13, 22, 16], dtype="int8") + B_1 = T.alloc_buffer([1, 6, 10, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 2, 2, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(13, 13, 16, 3, 3): + with T.block("B_0"): + T.where((i1_0 < 1 or 5 <= ax0) and (i2_0 < 1 or 5 <= ax1)) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(22, i1_0 * 8 + ax0) + ax2_1 = T.axis.spatial(22, i2_0 * 8 + ax1) + ax3_1, rv0, rv1 = T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1]) + with T.init(): + B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1] = T.int8(-128) + B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1] = T.max( + B_0[ax0_1, ax1_1 % 13, ax2_1, ax3_1], + A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1], + ) + for ax0, ax1, ax2, ax3, ax4 in T.grid(6, 6, 16, 3, 3): + with T.block("B_1"): + T.where((i1_0 < 1 or 2 <= ax0) and (i2_0 < 1 or 2 <= ax1)) + ax0_2 = T.axis.spatial(1, 0) + ax1_2 = T.axis.opaque(10, i1_0 * 4 + ax0) + ax2_2 = T.axis.spatial(10, i2_0 * 4 + ax1) + ax3_2, rv0, rv1 = 
T.axis.remap("SRR", [ax2, ax3, ax4]) + T.reads(B_0[ax0_2, (ax1_2 * 2 + rv0) % 13, ax2_2 * 2 + rv1, ax3_2]) + T.writes(B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2]) + with T.init(): + B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2] = T.int8(-128) + B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2] = T.max( + B_1[ax0_2, ax1_2 % 6, ax2_2, ax3_2], + B_0[ax0_2, (ax1_2 * 2 + rv0) % 13, ax2_2 * 2 + rv1, ax3_2], + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 3, 3): + with T.block("C"): + ax0_3 = T.axis.spatial(1, i0_0 + i0_1) + ax1_3 = T.axis.opaque(8, i1_0 * 4 + i1_1) + ax2_3 = T.axis.spatial(8, i2_0 * 4 + i2_1) + ax3_3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B_1[ax0_3, (ax1_3 + rv0) % 6, ax2_3 + rv1, ax3_3]) + T.writes(C[ax0_3, ax1_3, ax2_3, ax3_3]) + with T.init(): + C[ax0_3, ax1_3, ax2_3, ax3_3] = T.int8(-128) + C[ax0_3, ax1_3, ax2_3, ax3_3] = T.max( + C[ax0_3, ax1_3, ax2_3, ax3_3], + B_1[ax0_3, (ax1_3 + rv0) % 6, ax2_3 + rv1, ax3_3], + ) + + sch = tir.Schedule(cascade_3_max_pool2d_with_stride, debug_mask="all") + io, _ = _tile_nd(sch, [1, 4, 4, 16], "C") + sch.compute_at(sch.get_block("B_1"), io[-1]) + sch.compute_at(sch.get_block("B_0"), io[-1]) + sch.rolling_buffer(sch.get_block("B_0"), 0) + sch.rolling_buffer(sch.get_block("B_1"), 0) + check_rolling_buffer(sch, cascade_3_max_pool2d_with_stride, expected, check_run=True) + + +def test_upscale(): + @T.prim_func + def before(A: T.Buffer[(1, 16, 16, 16), "int8"], C: T.Buffer[(1, 24, 24, 16), "int8"]) -> None: + B = T.alloc_buffer([1, 14, 14, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 5, 5, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(5, 5, 16, 3, 3): + with T.block("B"): + T.where(i1_0 * 5 // 2 + ax0 < 14 and i2_0 * 5 // 2 + ax1 < 14) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.spatial(14, i1_0 * 5 // 2 + ax0) + ax2_1 = T.axis.spatial(14, i2_0 * 5 // 2 + ax1) + ax3_1 = T.axis.spatial(16, ax2) + rv0, rv1 = T.axis.remap("RR", [ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) + T.writes(B[ax0_1, ax1_1, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 5, 5, 16, 3, 3): + with T.block("C"): + T.where(i1_0 * 5 + i1_1 < 24 and i2_0 * 5 + i2_1 < 24) + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.spatial(24, i1_0 * 5 + i1_1) + ax2 = T.axis.spatial(24, i2_0 * 5 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, ax1 // 2 + rv0, ax2 // 2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, ax1 // 2 + rv0, ax2 // 2 + rv1, ax3] + ) + + @T.prim_func + def expected( + A: T.Buffer[(1, 16, 16, 16), "int8"], C: T.Buffer[(1, 24, 24, 16), "int8"] + ) -> None: + B = T.alloc_buffer([1, 5, 14, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 5, 5, 1): + for ax0, ax1, ax2, ax3, ax4 in T.grid(5, 5, 16, 3, 3): + with T.block("B"): + T.where( + i1_0 * 5 // 2 + ax0 < 14 + and i2_0 * 5 // 2 + ax1 < 14 + and (i1_0 < 1 or 2 <= ax0) + and (i2_0 < 1 or 2 <= ax1) + ) + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.opaque(14, i1_0 * 5 // 2 + ax0) + ax2_1 = T.axis.spatial(14, i2_0 * 5 // 2 + ax1) + ax3_1 = T.axis.spatial(16, ax2) + rv0, rv1 = T.axis.remap("RR", [ax3, ax4]) + T.reads(A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1]) 
+ T.writes(B[ax0_1, ax1_1 % 5, ax2_1, ax3_1]) + with T.init(): + B[ax0_1, ax1_1 % 5, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1 % 5, ax2_1, ax3_1] = T.max( + B[ax0_1, ax1_1 % 5, ax2_1, ax3_1], A[ax0_1, ax1_1 + rv0, ax2_1 + rv1, ax3_1] + ) + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 5, 5, 16, 3, 3): + with T.block("C"): + T.where(i1_0 * 5 + i1_1 < 24 and i2_0 * 5 + i2_1 < 24) + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.opaque(24, i1_0 * 5 + i1_1) + ax2 = T.axis.spatial(24, i2_0 * 5 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, (ax1 // 2 + rv0) % 5, ax2 // 2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, (ax1 // 2 + rv0) % 5, ax2 // 2 + rv1, ax3] + ) + + sch = tir.Schedule(before, debug_mask="all") + sch.rolling_buffer(sch.get_block("B"), 0) + check_rolling_buffer(sch, before, expected, check_run=True) + + +def test_rolling_buffer_match_fail(): + @T.prim_func + def func_non_overlap( + A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"] + ): + B = T.alloc_buffer([1, 12, 12, 16], dtype="int8") + for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 3, 3, 1): + for ax0, ax1, ax2 in T.grid(4, 4, 16): + with T.block("B"): + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.spatial(12, i1_0 * 4 + ax0) + ax2_1 = T.axis.spatial(12, i2_0 * 4 + ax1) + ax3 = T.axis.spatial(16, ax2) + T.reads(A[ax0_1, ax1_1, ax2_1, ax3]) + T.writes(B[ax0_1, ax1_1, ax2_1, ax3]) + with T.init(): + B[ax0_1, ax1_1, ax2_1, ax3] = T.int8(-128) + B[ax0_1, ax1_1, ax2_1, ax3] = A[ax0_1, ax1_1, ax2_1, ax3] + for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 1, 1): + with T.block("C"): + ax0 = T.axis.spatial(1, i0_0 + i0_1) + ax1 = T.axis.spatial(12, i1_0 * 4 + i1_1) + ax2 = T.axis.spatial(12, i2_0 * 4 + i2_1) + ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1) + rv0, rv1 = T.axis.remap("RR", [i4, i5]) + T.reads(B[ax0, ax1 + rv0, ax2 + rv1, ax3]) + T.writes(C[ax0, ax1, ax2, ax3]) + with T.init(): + C[ax0, ax1, ax2, ax3] = T.int8(-128) + C[ax0, ax1, ax2, ax3] = T.max( + C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, ax2 + rv1, ax3] + ) + + sch = tir.Schedule(func_non_overlap, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.rolling_buffer(sch.get_block("B"), 0) + + +def test_rolling_buffer_injection_invalid(): + sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") + # Block B is not compute_at to Block C, so rolling_buffer injection is invalid. 
+ _, _ = _tile_nd(sch, [1, 4, 8, 16], "C") + _, _ = _tile_nd(sch, [1, 4, 8, 16], "B") + with pytest.raises(tvm.tir.ScheduleError): + sch.rolling_buffer(sch.get_block("B"), 0) From b51b70454c0c247a8075c4329dff95c43637c08f Mon Sep 17 00:00:00 2001 From: LiangW <732811423@qq.com> Date: Tue, 18 Oct 2022 08:46:18 +0000 Subject: [PATCH 2/3] Address review comments --- include/tvm/tir/schedule/schedule.h | 2 +- python/tvm/tir/schedule/schedule.py | 8 +- src/tir/schedule/concrete_schedule.cc | 4 +- src/tir/schedule/concrete_schedule.h | 2 +- src/tir/schedule/primitive.h | 2 +- src/tir/schedule/primitive/rolling_buffer.cc | 128 ++++++++---------- src/tir/schedule/traced_schedule.cc | 6 +- src/tir/schedule/traced_schedule.h | 2 +- .../test_tir_schedule_rolling_buffer.py | 4 + 9 files changed, 73 insertions(+), 85 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index b39491eacd8a..4dfda5313ab2 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -682,7 +682,7 @@ class ScheduleNode : public runtime::Object { virtual void PadEinsum(const BlockRV& block_rv, const Array& padding) = 0; /******** Schedule: Buffer transformation ********/ - virtual void RollingBuffer(const BlockRV& block_rv, int buffer_index) = 0; + virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0; /******** Schedule: Misc ********/ /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index ad32cb46a9cb..4c7e199d48be 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -3048,7 +3048,7 @@ def after_pad_einsum( def rolling_buffer( self, block: Union[BlockRV, str], - buffer_index: int, + write_buffer_index: int, ) -> None: """Compute the target buffer via rolling buffering, select the outermost rollable axis with a positive bound overlap that appears in the block's ancestor loops @@ -3066,7 +3066,7 @@ def rolling_buffer( ---------- block : Union[BlockRV, str] The producer block of the buffer. - buffer_index : int + write_buffer_index : int The index of the buffer in block's write region. Examples @@ -3106,7 +3106,7 @@ def before_rolling_buffer( .. code-block:: python sch = tir.Schedule(before_rolling_buffer) - sch.rolling_buffer(sch.get_block("B"), buffer_index=0) + sch.rolling_buffer(sch.get_block("B"), write_buffer_index=0) print(sch.mod["main"].script()) After applying rolling_buffer, the IR becomes: @@ -3145,7 +3145,7 @@ def after_rolling_buffer( The region_cover property of the consumer block of the target buffer will become false. 
""" block = self._normalize_block_arg(block) - return _ffi_api.ScheduleRollingBuffer(self, block, buffer_index) # type: ignore # pylint: disable=no-member + return _ffi_api.ScheduleRollingBuffer(self, block, write_buffer_index) # type: ignore # pylint: disable=no-member ########## Schedule: Misc ########## diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index ae500fa563fb..dcb6d2d06b89 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -640,9 +640,9 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { /******** Schedule: Buffer Transformation ********/ -void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int buffer_index) { +void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) { TVM_TIR_SCHEDULE_BEGIN(); - tir::RollingBuffer(state_, this->GetSRef(block_rv), buffer_index); + tir::RollingBuffer(state_, this->GetSRef(block_rv), write_buffer_index); TVM_TIR_SCHEDULE_END("rolling-buffer", this->error_render_level_); this->state_->DebugVerify(); } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index e3eddd7f8395..9e001b139751 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -154,7 +154,7 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Schedule: Padding decomposition ********/ BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override; /******** Schedule: Buffer transformation ********/ - void RollingBuffer(const BlockRV& block_rv, int buffer_index) override; + void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) override; /******** Schedule: Misc ********/ void EnterPostproc() override {} diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index d5ce299ff036..51aed8bac21a 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -396,7 +396,7 @@ TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sr */ TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis); /******** Schedule: Buffer transformation ********/ -TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int buffer_index); +TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index); /******** Schedule: Block annotation ********/ /*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, offset) */ using StorageAlignTuple = Array; diff --git a/src/tir/schedule/primitive/rolling_buffer.cc b/src/tir/schedule/primitive/rolling_buffer.cc index 8c93de75f581..57274c840f5c 100644 --- a/src/tir/schedule/primitive/rolling_buffer.cc +++ b/src/tir/schedule/primitive/rolling_buffer.cc @@ -30,7 +30,7 @@ struct RollingBufferInfo { Buffer old_buffer; Buffer new_buffer; int rolling_axis; - int rolling_extent; + PrimExpr rolling_extent; std::vector axis_overlaps; std::vector> axis_iter_vars; /*! \brief The map used for ScheduleStateNode::Replace. 
*/ @@ -121,45 +121,40 @@ class RollingBufferInfoCollector { std::vector> bound_iter_vars; std::vector bound_overlaps; - auto stride = 0; - auto divisor = 1; - Optional iter_var; + + arith::PVar p_var; + arith::PVar p_stride, p_divisor; for (auto bound : region) { - divisor = 1; - if (auto floor_div = bound->min.as()) { + auto stride = 0; + auto divisor = 1; + + Optional iter_var; + if (floordiv((p_var * p_stride), p_divisor).Match(bound->min)) { // Handle the case of fractional strides // They take this form: floordiv(hh.outer, 2) // Strip the floordiv and keep track of the divisor - divisor = Downcast(floor_div->b)->value; - bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span); - } - if (bound->min.as()) { - // If the bound is an int, we can't roll over it - iter_var = NullOpt; - } else if (auto var = bound->min.as()) { + iter_var = p_var.Eval(); + divisor = p_divisor.Eval()->value; + stride = std::ceil(static_cast(p_stride.Eval()->value) / divisor); + } else if ((p_var * p_stride).Match(bound->min)) { + // The bound is the iter var multiplied by the stride + iter_var = p_var.Eval(); + stride = p_stride.Eval()->value; + } else if (p_var.Match(bound->min)) { // If the bound is just a Var, that implies the stride is 1 - iter_var = GetRef(var); + iter_var = p_var.Eval(); stride = 1; + } else if (is_const_int(bound->min)) { + // If the bound is an int, we can't roll over it + iter_var = NullOpt; } else { - // Otherwise, it's the iter var multiplied by the stride - // If not we're in unknown behaviour - if (auto mul = bound->min.as()) { - if (mul->a->IsInstance() && mul->b->IsInstance()) { - iter_var = Downcast(mul->a); - stride = Downcast(mul->b)->value; - } else { - return false; - } - } else { - return false; - } + // If all of the above matches fail, we're in unknown behaviour + return false; } - stride = std::ceil(static_cast(stride) / divisor); auto bound_overlap = 0; if (iter_var.defined()) { - auto extent = bound->extent.as(); - ICHECK(extent); - bound_overlap = extent->value - stride; + auto extent = Downcast(bound->extent)->value; + bound_overlap = extent - stride; // Since Pass CompactBufferAllocation will be responsible for compacting the buffer // allocation region, there is no need to roll over the axis where the overlap is not // positive, so reset iter_var to NullOpt. 
@@ -170,6 +165,7 @@ class RollingBufferInfoCollector { bound_iter_vars.push_back(iter_var); bound_overlaps.push_back(bound_overlap); } + Array loop_srefs = GetLoops(block_sref); // Pick the outermost iter_var that's mentioned in the bounds // to be the rolling axis @@ -200,7 +196,7 @@ class RollingBufferInfoCollector { info_.old_buffer = buffer; info_.new_buffer = new_buffer; info_.rolling_axis = roll_axis; - info_.rolling_extent = static_cast(Downcast(region[roll_axis]->extent)->value); + info_.rolling_extent = region[roll_axis]->extent; info_.axis_overlaps = bound_overlaps; info_.axis_iter_vars = bound_iter_vars; @@ -233,12 +229,28 @@ class RollingBufferRewriter : public StmtExprMutator { (*old_access_regions).MutateByApply(fmutate); } + void RewriteBufferAccess(Buffer* buffer, Array* indices) const { + Array new_indices; + new_indices.reserve(indices->size()); + // First modify the access indices to use modulo arithmetic + // for the rolling axis + for (size_t i = 0; i < indices->size(); ++i) { + if (static_cast(i) == info_->rolling_axis) { + new_indices.push_back(FloorMod((*indices)[i], info_->rolling_extent)); + } else { + new_indices.push_back((*indices)[i]); + } + } + // Replace the accessed buffer with the new buffer. + *buffer = info_->new_buffer; + *indices = std::move(new_indices); + } + Stmt VisitStmt_(const BlockNode* block) final { Block old_stmt = GetRef(block); Block stmt = Downcast(StmtExprMutator::VisitStmt_(block)); + BlockNode* n = stmt.CopyOnWrite(); if (block == scope_sref_->stmt) { - ObjectPtr n = make_object(*stmt.as()); - Array new_alloc_buffers; for (const Buffer& buffer : stmt->alloc_buffers) { if (buffer != info_->old_buffer) { @@ -248,9 +260,8 @@ class RollingBufferRewriter : public StmtExprMutator { } } n->alloc_buffers = std::move(new_alloc_buffers); - stmt = Block(n); } else { - Array new_iter_bindings; + Array new_iter_vars; for (size_t i = 0; i < stmt->iter_vars.size(); ++i) { auto old_iter_var = stmt->iter_vars[i]; if (static_cast(i) == info_->rolling_axis) { @@ -262,16 +273,15 @@ class RollingBufferRewriter : public StmtExprMutator { // during lowering phase. IterVar new_iter_var = IterVar(old_iter_var->dom, old_iter_var->var, IterVarType::kOpaque); - new_iter_bindings.push_back(new_iter_var); + new_iter_vars.push_back(new_iter_var); } else { - new_iter_bindings.push_back(old_iter_var); + new_iter_vars.push_back(old_iter_var); } } Map buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}}; auto infered_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer); - BlockNode* n = stmt.CopyOnWrite(); - n->iter_vars = std::move(new_iter_bindings); + n->iter_vars = std::move(new_iter_vars); RewriteAccessRegion(&n->reads, infered_access_regions[0]); RewriteAccessRegion(&n->writes, infered_access_regions[1]); } @@ -306,22 +316,8 @@ class RollingBufferRewriter : public StmtExprMutator { Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore stmt = Downcast(StmtExprMutator::VisitStmt_(op)); if (stmt->buffer.same_as(info_->old_buffer)) { - Array new_indices; - new_indices.reserve(stmt->indices.size()); - // First modify the access indices to use modulo arithmetic - // for the rolling axis - for (size_t i = 0; i < stmt->indices.size(); ++i) { - auto index = stmt->indices[i]; - if (static_cast(i) == info_->rolling_axis) { - new_indices.push_back(FloorMod(index, info_->rolling_extent)); - } else { - new_indices.push_back(index); - } - } BufferStoreNode* n = stmt.CopyOnWrite(); - // Replace the stored buffer with the new buffer. 
- n->buffer = info_->new_buffer; - n->indices = std::move(new_indices); + RewriteBufferAccess(&n->buffer, &n->indices); // Need to add predicate to the current block to avoid recomputing elements. rewrite_block_predicate_ = true; } @@ -331,20 +327,8 @@ class RollingBufferRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const BufferLoadNode* op) final { BufferLoad stmt = Downcast(StmtExprMutator::VisitExpr_(op)); if (stmt->buffer.same_as(info_->old_buffer)) { - Array new_indices; - new_indices.reserve(stmt->indices.size()); - for (size_t i{0}; i < stmt->indices.size(); ++i) { - auto index = stmt->indices[i]; - if (static_cast(i) == info_->rolling_axis) { - new_indices.push_back(FloorMod(index, info_->rolling_extent)); - } else { - new_indices.push_back(index); - } - } BufferLoadNode* n = stmt.CopyOnWrite(); - // Replace the loaded buffer with the new buffer. - n->buffer = info_->new_buffer; - n->indices = std::move(new_indices); + RewriteBufferAccess(&n->buffer, &n->indices); } return std::move(stmt); } @@ -401,7 +385,7 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf } BufferRegion relaxed_region = GetRelaxedBufferRegion(realize, buffer_region, dom_map); - // Step 4. Find an valid rolling axis and collect bound overlaps on the target buffer. + // Step 4. Find a valid rolling axis and collect bound overlaps on the target buffer. RollingBufferInfo info = RollingBufferInfoCollector::CheckAndGetRollingBufferInfo( self->mod, block_sref, relaxed_region); // Step 5. Mutate IR to apply rolling access pattern. @@ -423,14 +407,14 @@ struct RollingBufferTraits : public UnpackedInstTraits { static constexpr size_t kNumAttrs = 1; static constexpr size_t kNumDecisions = 0; - static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index) { - return sch->RollingBuffer(block, buffer_index.IntValue()); + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index) { + return sch->RollingBuffer(block, write_buffer_index.IntValue()); } - static String UnpackedAsPython(Array outputs, String block, Integer buffer_index) { + static String UnpackedAsPython(Array outputs, String block, Integer write_buffer_index) { PythonAPICall py("rolling_buffer"); py.Input("block", block); - py.Input("buffer_index", buffer_index); + py.Input("write_buffer_index", write_buffer_index); return py.Str(); } diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 3fa997376ac8..36bfbc247183 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -563,13 +563,13 @@ void TracedScheduleNode::PadEinsum(const BlockRV& block_rv, const Array /******** Schedule: Buffer transformation ********/ -void TracedScheduleNode::RollingBuffer(const BlockRV& block_rv, int buffer_index) { - ConcreteScheduleNode::RollingBuffer(block_rv, buffer_index); +void TracedScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) { + ConcreteScheduleNode::RollingBuffer(block_rv, write_buffer_index); static const InstructionKind& kind = InstructionKind::Get("RollingBuffer"); trace_->Append(/*inst=*/Instruction( /*kind=*/kind, /*inputs=*/{block_rv}, - /*attrs=*/{Integer(buffer_index)}, + /*attrs=*/{Integer(write_buffer_index)}, /*outputs=*/{})); } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 66dbd639cb01..450c19326a46 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -114,7 +114,7 @@ class 
TracedScheduleNode : public ConcreteScheduleNode { BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) final; void PadEinsum(const BlockRV& block_rv, const Array& padding) final; /******** Schedule: Buffer transformation ********/ - void RollingBuffer(const BlockRV& block_rv, int buffer_index) final; + void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) final; /******** Schedule: Misc ********/ void EnterPostproc() final; }; diff --git a/tests/python/unittest/test_tir_schedule_rolling_buffer.py b/tests/python/unittest/test_tir_schedule_rolling_buffer.py index ea87ef594023..edda7799801a 100644 --- a/tests/python/unittest/test_tir_schedule_rolling_buffer.py +++ b/tests/python/unittest/test_tir_schedule_rolling_buffer.py @@ -532,3 +532,7 @@ def test_rolling_buffer_injection_invalid(): _, _ = _tile_nd(sch, [1, 4, 8, 16], "B") with pytest.raises(tvm.tir.ScheduleError): sch.rolling_buffer(sch.get_block("B"), 0) + + +if __name__ == "__main__": + tvm.testing.main() From 46fbf36423ce1c4f64a806f2ad8d13bf180e1742 Mon Sep 17 00:00:00 2001 From: LiangW <732811423@qq.com> Date: Fri, 21 Oct 2022 09:56:43 +0000 Subject: [PATCH 3/3] Add dependency checks --- include/tvm/tir/schedule/schedule.h | 14 +++++ python/tvm/tir/schedule/schedule.py | 11 ++-- src/tir/schedule/concrete_schedule.cc | 21 ++++--- src/tir/schedule/primitive.h | 18 +++++- src/tir/schedule/primitive/rolling_buffer.cc | 49 ++++++++++++++- .../test_tir_schedule_rolling_buffer.py | 59 +++++++++++++++---- 6 files changed, 144 insertions(+), 28 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 4dfda5313ab2..9ba6f1a311de 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -682,6 +682,20 @@ class ScheduleNode : public runtime::Object { virtual void PadEinsum(const BlockRV& block_rv, const Array& padding) = 0; /******** Schedule: Buffer transformation ********/ + /*! + * \brief Compute the target buffer via rolling buffering. + * \details This primitive selects the outermost rollable axis with a positive bound overlap that + * appears in the block's ancestor loops as `rolling axis`, fold and circularize the buffer along + * the rolling dimension, append block predicate to avoid recomputing overlapping elements. + * It requires: + * 1) The buffer to be an intermediate buffer defined via `alloc_buffer`. + * 2) The LCA of the producer and consumer of the buffer is a for loop, typically, + * the producer and consumer of the buffer are cascaded through compute_at. + * 3) The access region of the buffer has at least one dimension that contains + * a positive bound overlap. + * \param block_rv The producer block of the buffer. + * \param write_buffer_index The index of the buffer in block's write region. + */ virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0; /******** Schedule: Misc ********/ diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 4c7e199d48be..832ca1d2e2e0 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -3052,14 +3052,17 @@ def rolling_buffer( ) -> None: """Compute the target buffer via rolling buffering, select the outermost rollable axis with a positive bound overlap that appears in the block's ancestor loops - as `rolling axis`. It requires: + as `rolling axis`, fold and circularize the buffer along the rolling dimension, + append block predicate to avoid recomputing overlapping elements. 
It requires: - 1) The buffer to be an intermediate buffer defined via `alloc_buffer`. + 1) The block is not an output block and has only RAW dependencies. - 2) The LCA of the producer and consumer of the buffer is a for loop, typically, + 2) The buffer to be an intermediate buffer defined via `alloc_buffer`. + + 3) The LCA of the producer and consumer of the buffer is a for loop, typically, the producer and consumer of the buffer are cascaded through compute_at. - 3) The access region of the buffer has at least one dimension that contains + 4) The access region of the buffer has at least one dimension that contains a positive bound overlap. Parameters diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index dcb6d2d06b89..5a7e3b15aea4 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -638,15 +638,6 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { this->state_->DebugVerify(); } -/******** Schedule: Buffer Transformation ********/ - -void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) { - TVM_TIR_SCHEDULE_BEGIN(); - tir::RollingBuffer(state_, this->GetSRef(block_rv), write_buffer_index); - TVM_TIR_SCHEDULE_END("rolling-buffer", this->error_render_level_); - this->state_->DebugVerify(); -} - /******** Schedule: Block Annotation ********/ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, @@ -810,6 +801,8 @@ void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_ this->state_->DebugVerify(); } +/******** Schedule: Padding ********/ + BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) { StmtSRef result{nullptr}; TVM_TIR_SCHEDULE_BEGIN(); @@ -825,6 +818,16 @@ void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const Arrayerror_render_level_); this->state_->DebugVerify(); } + +/******** Schedule: Buffer Transformation ********/ + +void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::RollingBuffer(state_, this->GetSRef(block_rv), write_buffer_index); + TVM_TIR_SCHEDULE_END("rolling-buffer", this->error_render_level_); + this->state_->DebugVerify(); +} + /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 51aed8bac21a..02ca0dcab432 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -395,8 +395,6 @@ TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sr * \return The sref of the rfactor block */ TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis); -/******** Schedule: Buffer transformation ********/ -TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index); /******** Schedule: Block annotation ********/ /*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, offset) */ using StorageAlignTuple = Array; @@ -526,6 +524,22 @@ TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const Array& padding); +/******** Schedule: Buffer transformation ********/ +/*! + * \brief Compute the target buffer via rolling buffering. 
+ * \details This primitive selects the outermost rollable axis with a positive bound overlap that + * appears in the block's ancestor loops as `rolling axis`, fold and circularize the buffer along + * the rolling dimension, append block predicate to avoid recomputing overlapping elements. + * It requires: + * 1) The buffer to be an intermediate buffer defined via `alloc_buffer`. + * 2) The LCA of the producer and consumer of the buffer is a for loop, typically, + * the producer and consumer of the buffer are cascaded through compute_at. + * 3) The access region of the buffer has at least one dimension that contains + * a positive bound overlap. + * \param block_rv The producer block of the buffer. + * \param write_buffer_index The index of the buffer in block's write region. + */ +TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index); /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/primitive/rolling_buffer.cc b/src/tir/schedule/primitive/rolling_buffer.cc index 57274c840f5c..c01d6c568fcd 100644 --- a/src/tir/schedule/primitive/rolling_buffer.cc +++ b/src/tir/schedule/primitive/rolling_buffer.cc @@ -50,6 +50,51 @@ BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferReg return BufferRegion(buffer_region->buffer, relaxed_region); } +class RollingBufferDependencyError : public ScheduleError { + public: + explicit RollingBufferDependencyError(IRModule mod, Block block) + : mod_(mod), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The target block is required to have only RAW dependencies"; + } + + String DetailRenderTemplate() const final { + return "The target block {0} is required to have only RAW dependencies"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + /*! + * \brief Check if the block has only RAW dependencies. + * \param self The schedule state + * \param block_sref The sref of the block to be checked + * \param scope_root_sref The sref of the scope root + * \throw ScheduleError if the block has WAW or WAR dependency. + */ + static void Check(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + BlockScope scope = self->GetBlockScope(scope_root_sref); + for (const Dependency& producers : scope->GetDepsByDst(block_sref)) { + if (!(producers->kind == DepKind::kRAW)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferDependencyError(self->mod, GetRef(block)); + } + } + for (const Dependency& consumers : scope->GetDepsBySrc(block_sref)) { + if (!(consumers->kind == DepKind::kRAW)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferDependencyError(self->mod, GetRef(block)); + } + } + } + + private: + IRModule mod_; + Block block_; +}; + class RollingBufferMatchError : public ScheduleError { public: RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) @@ -345,6 +390,7 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf /*! * Check * - The block is not an output block. + * - The block has only RAW dependencies. * - The block is tiled and there is access overlap between adjacent tiles. 
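+ *   (e.g. 6-wide producer tiles that advance by 4 along an axis overlap by 2 elements,
+ *   so that axis is a rolling candidate)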
* Mutate * - Select the outermost rollable axis appeared in the block's loop nest @@ -361,8 +407,9 @@ void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buf const BufferRegion& buffer_region = GetNthAccessBufferRegion(self, block, write_buffer_index, BufferIndexType::kWrite); StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - // Step 2. Check the target block is not an output block. + // Step 2. Check if the target block is not an output block and has only RAW dependencies. CheckNotOutputBlock(self, block_sref, scope_root_sref); + RollingBufferDependencyError::Check(self, block_sref, scope_root_sref); // Step 3. Find the lca of the access location of the target buffer and relax the buffer Array loop_srefs = GetLoops(block_sref); diff --git a/tests/python/unittest/test_tir_schedule_rolling_buffer.py b/tests/python/unittest/test_tir_schedule_rolling_buffer.py index edda7799801a..c55c41e451cc 100644 --- a/tests/python/unittest/test_tir_schedule_rolling_buffer.py +++ b/tests/python/unittest/test_tir_schedule_rolling_buffer.py @@ -119,16 +119,12 @@ def cascade_2_max_pool2d(A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 8 for i0, i1, i2, i3, i4, i5 in T.grid(1, 10, 10, 16, 3, 3): with T.block("B"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) - T.reads(A[ax0, ax1 + rv0, ax2 + rv1, ax3]) - T.writes(B[ax0, ax1, ax2, ax3]) with T.init(): B[ax0, ax1, ax2, ax3] = T.int8(-128) B[ax0, ax1, ax2, ax3] = T.max(B[ax0, ax1, ax2, ax3], A[ax0, ax1 + rv0, ax2 + rv1, ax3]) for i0, i1, i2, i3, i4, i5 in T.grid(1, 8, 8, 16, 3, 3): with T.block("C"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) - T.reads(B[ax0, ax1 + rv0, ax2 + rv1, ax3]) - T.writes(C[ax0, ax1, ax2, ax3]) with T.init(): C[ax0, ax1, ax2, ax3] = T.int8(-128) C[ax0, ax1, ax2, ax3] = T.max(C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, ax2 + rv1, ax3]) @@ -143,8 +139,6 @@ def cascade_3_max_pool2d_with_stride( for i0, i1, i2, i3, i4, i5 in T.grid(1, 22, 22, 16, 3, 3): with T.block("B_0"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) - T.reads(A[ax0, ax1 + rv0, ax2 + rv1, ax3]) - T.writes(B_0[ax0, ax1, ax2, ax3]) with T.init(): B_0[ax0, ax1, ax2, ax3] = T.int8(-128) B_0[ax0, ax1, ax2, ax3] = T.max( @@ -153,8 +147,6 @@ def cascade_3_max_pool2d_with_stride( for i0, i1, i2, i3, i4, i5 in T.grid(1, 10, 10, 16, 3, 3): with T.block("B_1"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) - T.reads(B_0[ax0, ax1 * 2 + rv0, ax2 * 2 + rv1, ax3]) - T.writes(B_1[ax0, ax1, ax2, ax3]) with T.init(): B_1[ax0, ax1, ax2, ax3] = T.int8(-128) B_1[ax0, ax1, ax2, ax3] = T.max( @@ -163,8 +155,6 @@ def cascade_3_max_pool2d_with_stride( for i0, i1, i2, i3, i4, i5 in T.grid(1, 8, 8, 16, 3, 3): with T.block("C"): ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) - T.reads(B_1[ax0, ax1 + rv0, ax2 + rv1, ax3]) - T.writes(C[ax0, ax1, ax2, ax3]) with T.init(): C[ax0, ax1, ax2, ax3] = T.int8(-128) C[ax0, ax1, ax2, ax3] = T.max( @@ -487,7 +477,52 @@ def expected( check_rolling_buffer(sch, before, expected, check_run=True) -def test_rolling_buffer_match_fail(): +def test_fail_rolling_buffer_multi_writers(): + @T.prim_func + def func_multi_writers( + A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"] + ): + B = T.alloc_buffer([1, 12, 12, 16], dtype="int8") + for i0, i1, i2, i3 in T.grid(1, 3, 3, 1): + for ax0, ax1, ax2 in 
T.grid(6, 6, 16): + with T.block("B_writer_0"): + ax0_1 = T.axis.spatial(1, i0) + ax1_1 = T.axis.spatial(12, i1 * 4 + ax0) + ax2_1 = T.axis.spatial(12, i2 * 4 + ax1) + ax3_1 = T.axis.spatial(16, ax2) + with T.init(): + B[ax0_1, ax1_1, ax2_1, ax3_1] = T.int8(-128) + B[ax0_1, ax1_1, ax2_1, ax3_1] = A[ax0_1, ax1_1, ax2_1, ax3_1] + T.int8(1) + for ax0, ax1, ax2 in T.grid(6, 6, 16): + with T.block("B_writer_1"): + ax0_2 = T.axis.spatial(1, i0) + ax1_2 = T.axis.spatial(12, i1 * 4 + ax0) + ax2_2 = T.axis.spatial(12, i2 * 4 + ax1) + ax3_2 = T.axis.spatial(16, ax2) + with T.init(): + B[ax0_2, ax1_2, ax2_2, ax3_2] = T.int8(-128) + B[ax0_2, ax1_2, ax2_2, ax3_2] = B[ax0_2, ax1_2, ax2_2, ax3_2] + A[ + ax0_2, ax1_2, ax2_2, ax3_2 + ] * T.int8(2) + for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(1, 4, 4, 16, 3, 3): + with T.block("C"): + ax0_3 = T.axis.spatial(1, i0 + ax0) + ax1_3 = T.axis.spatial(12, i1 * 4 + ax1) + ax2_3 = T.axis.spatial(12, i2 * 4 + ax2) + ax3_3 = T.axis.spatial(16, i3 * 16 + ax3) + rv0, rv1 = T.axis.remap("RR", [ax4, ax5]) + with T.init(): + C[ax0_3, ax1_3, ax2_3, ax3_3] = T.int8(-128) + C[ax0_3, ax1_3, ax2_3, ax3_3] = T.max( + C[ax0_3, ax1_3, ax2_3, ax3_3], B[ax0_3, ax1_3 + rv0, ax2_3 + rv1, ax3_3] + ) + + sch = tir.Schedule(func_multi_writers, debug_mask="all") + with pytest.raises(tvm.tir.ScheduleError): + sch.rolling_buffer(sch.get_block("B_writer_0"), 0) + + +def test_fail_rolling_buffer_not_match(): @T.prim_func def func_non_overlap( A: T.Buffer[(1, 12, 12, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"] @@ -525,7 +560,7 @@ def func_non_overlap( sch.rolling_buffer(sch.get_block("B"), 0) -def test_rolling_buffer_injection_invalid(): +def test_fail_rolling_buffer_injection_invalid(): sch = tir.Schedule(cascade_2_max_pool2d, debug_mask="all") # Block B is not compute_at to Block C, so rolling_buffer injection is invalid. _, _ = _tile_nd(sch, [1, 4, 8, 16], "C")
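+    # A schedule where rolling_buffer does apply would cascade the two blocks
+    # first, e.g. (a minimal sketch, assuming _tile_nd returns the outer and
+    # inner tile loops):
+    #   outer, _ = _tile_nd(sch, [1, 4, 8, 16], "C")
+    #   sch.compute_at(sch.get_block("B"), outer[-1])
+    #   sch.rolling_buffer(sch.get_block("B"), 0)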