From b723d38f3d04cbe5022412966c72b0ad7174ac42 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 9 Aug 2021 16:25:20 -0400 Subject: [PATCH 1/5] [TensorIR][M2a] Storage Align This PR is part of the TensorIR upstreaming effort (#7527), which adds the one schedule primitive storage_align. Co-authored-by: Siyuan Feng Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Junru Shao --- include/tvm/tir/schedule/schedule.h | 12 + python/tvm/tir/schedule/schedule.py | 71 ++++ src/tir/schedule/analysis.h | 24 ++ src/tir/schedule/analysis/analysis.cc | 48 +++ src/tir/schedule/concrete_schedule.cc | 10 + src/tir/schedule/concrete_schedule.h | 3 + src/tir/schedule/primitive.h | 12 + src/tir/schedule/primitive/block_annotate.cc | 314 ++++++++++++++++++ src/tir/schedule/schedule.cc | 3 + src/tir/schedule/traced_schedule.cc | 13 + src/tir/schedule/traced_schedule.h | 3 + src/tir/transforms/compact_buffer_region.cc | 79 ++++- .../test_tir_schedule_storage_align.py | 182 ++++++++++ ...est_tir_transform_compact_buffer_region.py | 49 +++ 14 files changed, 820 insertions(+), 3 deletions(-) create mode 100644 src/tir/schedule/primitive/block_annotate.cc create mode 100644 tests/python/unittest/test_tir_schedule_storage_align.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index e2083778431e..25a2022281d2 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -264,6 +264,18 @@ class ScheduleNode : public runtime::Object { * \return The rfactor block */ virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0; + /******** Schedule: Block annotation ********/ + /*! + * \brief Set alignment requirement for specific dimension such that + * stride[axis] == k * factor + offset for some k. + * \param block_rv The producer block of the buffer + * \param buffer_index The index of the buffer in block's write region + * \param axis The dimension to be specified for alignment + * \param factor The factor multiple of alignment + * \param offset The required offset factor + */ + virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + int offset) = 0; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 4bbb5b9b1582..6e381443b52d 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -710,6 +710,77 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None: """ return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member + ######## Schedule: Block annotatoin ######## + + def storage_align( # pylint: disable=too-many-arguments + self, block: BlockRV, buffer_index: int, axis: int, factor: int, offset: int + ) -> None: + """Set alignment requirement for specific dimension such that + stride[axis] == k * factor + offset for some k. + + Parameters + ---------- + block : BlockRV + The producer block of the buffer. + buffer_index : int + The index of the buffer in block's write region. + axis : int + The dimension to be specified for alignment. + factor : int + The factor multiple of alignment. + offset : int + The required offset factor. + + Examples + -------- + + Before storage_align, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_storage_align(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do storage_align: + + .. code-block:: python + + sch = tir.Schedule(before_storage_align) + sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1) + print(tvm.script.asscript(sch.mod["main"])) + + After applying rfactor, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_storage_align(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + tir.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + After lowering passes, buffer B will have strides as [129, 1]. + + Note + ---- + Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`. + """ + _ffi_api.ScheduleStorageAlign( # type: ignore # pylint: disable=no-member + self, block, buffer_index, axis, factor, offset + ) + ########## Schedule: Blockize & Tensorize ########## ########## Schedule: Annotation ########## diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 9baf4b5245ea..cc7df0dbefbb 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -202,6 +202,19 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self */ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); +/******** Block-buffer relation ********/ + +/*! + * \brief Get the BlockRealize of the single child block of the block or loop specified by + * `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks + * \param self The schedule state + * \param block The queried block + * \param n The index of the queried buffer + * \return The buffer of the n-th write region of the block. + * \throw ScheduleError If the buffer index is out of bound. + */ +Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n); + /******** Commutative Reducer ********/ /*! @@ -224,6 +237,17 @@ std::vector> GetReducerGetters(); bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs); +/******** Annotation ********/ + +/*! + * \brief Create a new block with the given annotation added + * \param block The block with original annotation + * \param attr_key The annotation key to be added + * \param attr_value The annotation value to be added + * \return A new block with the given annotation as its last annotation + */ +Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 3ee98ec5b7d2..f5e9a675e14b 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -527,6 +527,54 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } +/******** Block-buffer relation ********/ + +Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) { + class WriteBufferIndexOutOfRangeError : public ScheduleError { + public: + explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index) + : mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {} + + String FastErrorString() const final { + return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " + "range [0, num_write_regions) where `num_write_regions` is the number of buffer " + "regions written by the block."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + size_t num_writes = block_->writes.size(); + os << "The block {0} has " << num_writes + << " write regions, so `buffer_index` is required to be in [0, " << num_writes + << "). However, the input `buffer_index` is " << buffer_index_ + << ", which is out of the expected range"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + int buffer_index_; + }; + + if (n < 0 || static_cast(n) >= block->writes.size()) { + throw WriteBufferIndexOutOfRangeError(self->mod, block, n); + } + return block->writes[n]->buffer; +} + +/******** Annotation ********/ +Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) { + Map annotations = block->annotations; + annotations.Set(attr_key, attr_value); + ObjectPtr new_block = make_object(*block); + new_block->annotations = std::move(annotations); + return Block(new_block); +} + /******** Pattern Matcher ********/ /*! diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 610628c6d88a..688ea8059c0e 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -362,6 +362,16 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { } /******** Schedule: loop binding/annotation ********/ +/******** Schedule: block annotation ********/ + +void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, + int factor, int offset) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis, factor, offset); + TVM_TIR_SCHEDULE_END("storage-align", this->error_render_level_); + this->state_->DebugVerify(); +} + /******** Schedule: cache read/write ********/ /******** Schedule: reduction ********/ diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index ec0dd079243b..cfdd9c8452f7 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -88,6 +88,9 @@ class ConcreteScheduleNode : public ScheduleNode { void ReverseComputeInline(const BlockRV& block) override; /******** Schedule: Reduction ********/ BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; + /******** Schedule: Block annotation ********/ + void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + int offset) override; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 22e25f1c54a7..01ee59038430 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -104,6 +104,18 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref * \return The sref of the rfactor block */ TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis); +/******** Schedule: Block annotation ********/ +/*! + * \brief Set alignment requirement for specific dimension such that + * stride[axis] == k * factor + offset for some k + * \param block_sref The producer block of the buffer + * \param buffer_index The index of the buffer in block's write region + * \param axis The dimension to be specified for alignment + * \param factor The factor multiple of alignment + * \param offset The required offset factor + */ +TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + int axis, int factor, int offset); /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc new file mode 100644 index 000000000000..73d4b8919bfd --- /dev/null +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +class StorageAlignAxisOutOfRangeError : public ScheduleError { + public: + explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis) + : mod_(std::move(mod)), buffer_(std::move(buffer)), axis_(axis) {} + + String FastErrorString() const final { + return "ScheduleError: The input `axis` is out of range. It is required to be in range " + "[-ndim, ndim) where `ndim` is the number of dimensions of the buffer to set " + "storage alignment."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + int ndim = static_cast(buffer_->shape.size()); + os << "The buffer to set storage alignment " << buffer_->name << " has " << ndim + << " dimension(s), so `axis` is required to be in [" << -(ndim) << ", " << ndim + << ") for storage_align. However, the input `axis` is " << axis_ + << ", which is out of the expected range."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int axis) { + int ndim = static_cast(buffer->shape.size()); + if (axis < -ndim || axis >= ndim) { + throw StorageAlignAxisOutOfRangeError(mod, buffer, axis); + } + // If axis is negative, convert it to a non-negative one. + if (axis < 0) { + axis += ndim; + } + return axis; + } + + private: + IRModule mod_; + Buffer buffer_; + int axis_; +}; + +/*! + * \brief Find the defining site of the buffer in the given block and its ancestors + * \param block_sref The block sref + * \param buffer The buffer + * \return The defining site of the buffer and whether the buffer is allocated (otherwise the + * buffer is from match_buffer). + */ +std::pair GetBufferDefiningSite(const StmtSRef& block_sref, const Buffer& buffer) { + // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or + // match_buffers. + const StmtSRefNode* defining_site_sref = block_sref.get(); + while (defining_site_sref != nullptr) { + const auto* block = defining_site_sref->StmtAs(); + // If this sref is not a block sref, skip it. + if (block == nullptr) { + defining_site_sref = defining_site_sref->parent; + continue; + } + // Try to find the buffer in `allloc_buffers` + for (const Buffer& alloc_buffer : block->alloc_buffers) { + if (buffer.same_as(alloc_buffer)) { + return {GetRef(defining_site_sref), true}; + } + } + // We do not allow the buffer being defined in `match_buffer`. + for (const MatchBufferRegion match_buffer : block->match_buffers) { + if (buffer.same_as(match_buffer)) { + return {GetRef(defining_site_sref), false}; + } + } + defining_site_sref = defining_site_sref->parent; + } + // If we cannot find the defining site block, it means that the buffer must be in the function's + // buffer_map, which isn't an intermediate buffer. + return {StmtSRef(), false}; +} + +class NonAllocatedBufferError : public ScheduleError { + public: + explicit NonAllocatedBufferError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} + + String FastErrorString() const final { + return "ScheduleError: The input buffer is not allocated by a block. This means the buffer is " + " either a function parameter or defined in `match_buffer` of a block."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The input buffer " << buffer_->name + << " is not allocated by a block. This means the buffer is either a function parameter or " + "defined in `match_buffer` of a block."; + return os.str(); + } + + static void CheckBufferAllocated(const IRModule& mod, const StmtSRef& block_sref, + const Buffer& buffer) { + StmtSRef defining_site_sref; + bool is_alloc; + std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, buffer); + if (!defining_site_sref.defined() || !is_alloc) { + throw NonAllocatedBufferError(mod, buffer); + } + } + + Array LocationsOfInterest() const final { return {}; } + IRModule mod() const final { return mod_; } + + private: + IRModule mod_; + Buffer buffer_; +}; + +class StorageAlignInvalidFactorError : public ScheduleError { + public: + explicit StorageAlignInvalidFactorError(IRModule mod, int factor) + : mod_(std::move(mod)), factor_(factor) {} + + String FastErrorString() const final { + return "ScheduleError: The input `factor` of storage_align is expected to be a positive " + "number."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The input `factor` of storage_align is expected to be a positive number. However, the " + "input `factor` is " + << factor_ << ", which is out of the expected range."; + return os.str(); + } + + static void Check(const IRModule& mod, int factor) { + if (factor <= 0) { + throw StorageAlignInvalidFactorError(mod, factor); + } + } + + Array LocationsOfInterest() const final { return {}; } + IRModule mod() const final { return mod_; } + + private: + IRModule mod_; + int factor_; +}; + +class StorageAlignInvalidAnnotationError : public ScheduleError { + public: + explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The block annotation for storage align is expected to be an array of " + "3-integer-tuples (axis, factor, offset)."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The block annotation for storage align is expected to be an array of 3-integer-tuples " + "(axis, factor, offset). However, the block annotation with key " + << attr::buffer_dim_align << " of the block {0} is " + << block_->annotations.at(attr::buffer_dim_align) << ", which is unexpected."; + return os.str(); + } + + static Array>> CheckAndGetAnnotation(const IRModule& mod, + const Block& block) { + // Get existing annotation value. + auto it = block->annotations.find(attr::buffer_dim_align); + if (it != block->annotations.end()) { + if (!IsValidAnnotation(block, (*it).second)) { + throw StorageAlignInvalidAnnotationError(mod, block); + } + return Downcast>>>((*it).second); + } + + // Create new annotation value + Array>> storage_align_annotation; + storage_align_annotation.resize(block->writes.size()); + return storage_align_annotation; + } + + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod() const final { return mod_; } + + private: + static bool IsValidAnnotation(const Block& block, const ObjectRef& anno_value) { + if (!anno_value->IsInstance()) { + return false; + } + const auto& buffer_annotations = Downcast>(anno_value); + if (buffer_annotations.size() != block->writes.size()) { + return false; + } + for (const ObjectRef& buffer_annotation : buffer_annotations) { + if (!buffer_annotation->IsInstance()) { + return false; + } + const auto& dim_annotations = Downcast>(buffer_annotation); + for (const ObjectRef& dim_annotation : dim_annotations) { + if (!dim_annotation->IsInstance()) { + return false; + } + const auto& dim_anno = Downcast>(dim_annotation); + // Check if the annotations are consist of 3-tuples. + if (dim_anno.size() != 3) { + return false; + } + for (const ObjectRef& dim_anno_element : dim_anno) { + if (!dim_anno_element->IsInstance()) { + return false; + } + } + } + } + return true; + } + + IRModule mod_; + Block block_; +}; + +void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, + int factor, int offset) { + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + Buffer buffer = GetNthWriteBuffer(self, GetRef(block_ptr), buffer_index); + StorageAlignInvalidFactorError::Check(self->mod, factor); + axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis); + NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer); + + // Step 1: Get existing or create new annotation value. + auto storage_align_annotation = StorageAlignInvalidAnnotationError::CheckAndGetAnnotation( + self->mod, GetRef(block_ptr)); + + // Step 2: Update the annotation value + Array> buffer_storage_align = storage_align_annotation[buffer_index]; + bool found = false; + for (size_t j = 0; j < buffer_storage_align.size(); ++j) { + ICHECK(buffer_storage_align[j].size() == 3); + if (buffer_storage_align[j][0] == axis) { + buffer_storage_align.Set(j, {Integer(axis), Integer(factor), Integer(offset)}); + found = true; + break; + } + } + if (!found) { + buffer_storage_align.push_back({Integer(axis), Integer(factor), Integer(offset)}); + } + storage_align_annotation.Set(buffer_index, std::move(buffer_storage_align)); + + // Step 3: Replace the block with the new annotation + Block new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation); + self->Replace(block_sref, new_block, {{GetRef(block_ptr), new_block}}); +} + +/******** Instruction Registration ********/ + +struct StorageAlignTraits : public UnpackedInstTraits { + static constexpr const char* kName = "StorageAlign"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 4; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, + Integer axis, Integer factor, Integer offset) { + return sch->StorageAlign(block_rv, buffer_index->value, axis->value, factor->value, + offset->value); + } + + static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, + Integer axis, Integer factor, Integer offset) { + PythonAPICall py("storage_align"); + py.Input("block", block_rv); + py.Input("buffer_index", buffer_index); + py.Input("axis", axis); + py.Input("factor", factor); + py.Input("offset", offset); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 3232a3344ee7..d6dc0b446e16 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -135,6 +135,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") /******** (FFI) Reduction ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor") .set_body_method(&ScheduleNode::RFactor); +/******** (FFI) Block annotation ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") + .set_body_method(&ScheduleNode::StorageAlign); /******** (FFI) Blockize & Tensorize ********/ /******** (FFI) Annotation ********/ /******** (FFI) Misc ********/ diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index d664d7f6ce98..e0ffdc7b019f 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -137,6 +137,19 @@ BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { return result; } +/******** Schedule: Block annotation ********/ + +void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, + int factor, int offset) { + ConcreteScheduleNode::StorageAlign(block_rv, buffer_index, axis, factor, offset); + static const InstructionKind& kind = InstructionKind::Get("StorageAlign"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index), Integer(axis), Integer(factor), Integer(offset)}, + /*outputs=*/{})); +} + /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index b4518cbba8b5..4650c44ba8c3 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -61,6 +61,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { void ReverseComputeInline(const BlockRV& block_rv) final; /******** Schedule: Reduction ********/ BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final; + /******** Schedule: Block annotation ********/ + void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + int offset) final; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index b1a4fd45ef0d..a549061b5f1b 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -303,18 +303,62 @@ class BufferAccessRegionCollector : public StmtExprVisitor { support::Arena arena_; }; +/*! \brief Collect storage alignment information from block annotations. */ +class StorageAlignCollector : public StmtVisitor { + public: + static std::unordered_map>, ObjectPtrHash, ObjectPtrEqual> Collect( + const PrimFunc& f) { + StorageAlignCollector collector; + collector(f->body); + return std::move(collector.storage_align_); + } + + private: + void VisitStmt_(const BlockNode* op) final { + auto it = op->annotations.find(attr::buffer_dim_align); + if (it != op->annotations.end()) { + const auto& storage_align = Downcast>>>((*it).second); + ICHECK(storage_align.size() == op->writes.size()); + for (size_t i = 0; i < storage_align.size(); ++i) { + CHECK(!storage_align_.count(op->writes[i]->buffer)) + << "ValueError: Conflicting storage_align for buffer " << op->writes[i]->buffer->name; + storage_align_.emplace(op->writes[i]->buffer, storage_align[i]); + } + } + StmtVisitor::VisitStmt_(op); + } + + /*! \brief The map from Buffer to its storage alignment information. */ + std::unordered_map>, ObjectPtrHash, ObjectPtrEqual> storage_align_; +}; + /*! \brief Reallocate the buffers with minimal region. */ class BufferCompactor : public StmtExprMutator { public: static Stmt Compact( const PrimFunc& f, - const std::unordered_map& regions) { + const std::unordered_map& regions, + const std::unordered_map>, ObjectPtrHash, ObjectPtrEqual>& + storage_align) { std::unordered_map buffer_info; for (const auto& kv : regions) { const Buffer& buffer = kv.first; Region region = kv.second; - buffer_info.emplace(buffer, BufferAllocInfo(std::move(region))); + BufferAllocInfo buffer_alloc_info(std::move(region)); + auto it = storage_align.find(buffer); + if (it != storage_align.end()) { + std::vector dim_aligns(buffer->shape.size()); + for (const Array& dim_align : (*it).second) { + ICHECK(dim_align.size() == 3); + int dim = dim_align[0]->value; + int factor = dim_align[1]->value; + int offset = dim_align[2]->value; + dim_aligns.at(dim) = {factor, offset}; + } + buffer_alloc_info.dim_aligns = std::move(dim_aligns); + } + buffer_info.emplace(buffer, std::move(buffer_alloc_info)); } BufferCompactor compactor(std::move(buffer_info)); Stmt stmt = compactor(f->body); @@ -322,9 +366,19 @@ class BufferCompactor : public StmtExprMutator { } private: + /*! \brief The storage alignment for a dimension */ + struct DimAlignInfo { + /*! \brief The factor of the alignment */ + int align_factor{0}; + /*! \brief The offset of the alignment */ + int align_offset{0}; + }; + struct BufferAllocInfo { /*! \brief The buffer access region. */ Region region; + /*! \brief The storage alignment information. */ + std::vector dim_aligns; /*! * \brief The reallocated buffer with minimal size. * \note The value if NullOpt if the buffer do not need reallocate (e.g parameter buffer). @@ -380,8 +434,25 @@ class BufferCompactor : public StmtExprMutator { for (const Range& range : info.region) { shape.push_back(range->extent); } + Array strides; + if (info.dim_aligns.size()) { + ICHECK(info.dim_aligns.size() == shape.size()); + strides.resize(shape.size()); + PrimExpr stride = make_const(shape[0].dtype(), 1); + for (size_t i = shape.size(); i != 0; --i) { + size_t dim = i - 1; + if (info.dim_aligns[dim].align_factor != 0) { + PrimExpr factor = make_const(stride.dtype(), info.dim_aligns[dim].align_factor); + PrimExpr offset = make_const(stride.dtype(), info.dim_aligns[dim].align_offset); + stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); + } + strides.Set(dim, stride); + stride = stride * shape[dim]; + } + } ObjectPtr n = make_object(*buffer.get()); n->shape = std::move(shape); + n->strides = std::move(strides); info.new_buffer = Buffer(std::move(n)); result.push_back(info.new_buffer); } @@ -458,7 +529,9 @@ PrimFunc CompactBufferAllocation(PrimFunc f) { PrimFuncNode* fptr = f.CopyOnWrite(); std::unordered_map region = BufferAccessRegionCollector::Collect(f); - fptr->body = BufferCompactor::Compact(f, region); + std::unordered_map>, ObjectPtrHash, ObjectPtrEqual> storage_align = + StorageAlignCollector::Collect(f); + fptr->body = BufferCompactor::Compact(f, region, storage_align); return f; } else { return f; diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py new file mode 100644 index 000000000000..33b430af43dd --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name + +@tvm.script.tir +def element_wise(a: ty.handle, c: ty.handle) -> None: + C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in tir.serial(0, 128): + for ax1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, ax1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + B[vi, vj] = (A[vi, vj]*tir.float32(2)) + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi_1, vj_1]: + tir.bind(vi_1, i0) + tir.bind(vj_1, i1) + tir.reads([B[vi_1, vj_1]]) + tir.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) + + +@tvm.script.tir +def element_wise_storage_align(a: ty.handle, c: ty.handle) -> None: + C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in tir.serial(0, 128): + for ax1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, ax1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + tir.block_attr({"buffer_dim_align":[[[0, 128, 127]]]}) + B[vi, vj] = (A[vi, vj]*tir.float32(2)) + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi_1, vj_1]: + tir.bind(vi_1, i0) + tir.bind(vj_1, i1) + tir.reads([B[vi_1, vj_1]]) + tir.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) + + +@tvm.script.tir +def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None: + C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in tir.serial(0, 128): + for ax1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.block_attr({"buffer_dim_align": [0]}) + tir.bind(vi, i0) + tir.bind(vj, ax1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + B[vi, vj] = (A[vi, vj]*tir.float32(2)) + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi_1, vj_1]: + tir.bind(vi_1, i0) + tir.bind(vj_1, i1) + tir.reads([B[vi_1, vj_1]]) + tir.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) + + +def test_storage_align(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + s.storage_align(B, 0, axis=0, factor=128, offset=127) + tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_storage_align_update(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + s.storage_align(B, 0, axis=0, factor=128, offset=0) + s.storage_align(B, 0, axis=0, factor=128, offset=127) + tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_storage_align_invalid_factor1(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=0, factor=0, offset=127) + + +def test_storage_align_invalid_factor2(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=0, factor=-1, offset=127) + + +def test_storage_align_invalid_buffer(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + C = s.get_block("C") + with pytest.raises(tir.ScheduleError): + s.storage_align(C, 0, axis=0, factor=128, offset=127) + + +def test_storage_align_invalid_buffer_index(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 2, axis=0, factor=128, offset=127) + + +def test_storage_align_invalid_axis(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=2, factor=128, offset=127) + + +def test_storage_align_invalid_annotation(): + func = element_wise_invalid_annotation + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=2, factor=128, offset=127) + + +if __name__ == "__main__": + test_storage_align() + test_storage_align_update() + test_storage_align_invalid_factor1() + test_storage_align_invalid_factor2() + test_storage_align_invalid_buffer() + test_storage_align_invalid_buffer_index() + test_storage_align_invalid_axis() + test_storage_align_invalid_annotation() diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index fb53b420f4ce..4ebe4af88434 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -339,6 +339,50 @@ def compacted_match_buffer_func(a: ty.handle, c: ty.handle) -> None: C1[()] = B2[()] * 2.0 +@tvm.script.tir +def storage_align_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer((16, 16), "float32") + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(A[i, j]) + tir.writes(B[i, j]) + tir.block_attr({"buffer_dim_align": [[[0, 16, 15]]]}) + B[i, j] = A[i, j] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(B[i, j]) + tir.writes(C[i, j]) + C[i, j] = B[i, j] * 2.0 + + +@tvm.script.tir +def compacted_storage_align_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer((1, 16), strides=(31, 1), dtypes="float32") + for j in range(0, 16): + with tir.block() as []: + tir.reads(A[i, j]) + tir.writes(B[0, j]) + tir.block_attr({"buffer_dim_align": [[[0, 16, 15]]]}) + B[0, j] = A[i, j] + 1.0 + for j in range(0, 16): + with tir.block() as []: + tir.reads(B[0, j]) + tir.writes(C[i, j]) + C[i, j] = B[0, j] * 2.0 + + def test_elementwise(): _check(elementwise_func, compacted_elementwise_func) @@ -380,6 +424,10 @@ def test_lower_te(): tvm.ir.assert_structural_equal(mod, orig_mod) # CompactBufferAllocation should do nothing on TE +def test_storage_align(): + _check(storage_align_func, compacted_storage_align_func) + + if __name__ == "__main__": test_elementwise() test_unschedulable_block() @@ -389,3 +437,4 @@ def test_lower_te(): test_symbolic() test_complex() test_match_buffer() + test_storage_align() From d33af69a75d1ea074ac6a35191076dda4403de70 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 9 Aug 2021 19:14:04 -0400 Subject: [PATCH 2/5] Update src/tir/schedule/primitive/block_annotate.cc Co-authored-by: Tristan Konolige --- src/tir/schedule/primitive/block_annotate.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 73d4b8919bfd..f728267a32b1 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -35,7 +35,7 @@ class StorageAlignAxisOutOfRangeError : public ScheduleError { String DetailRenderTemplate() const final { std::ostringstream os; int ndim = static_cast(buffer_->shape.size()); - os << "The buffer to set storage alignment " << buffer_->name << " has " << ndim + os << "The buffer to set storage alignment of, " << buffer_->name << ", has " << ndim << " dimension(s), so `axis` is required to be in [" << -(ndim) << ", " << ndim << ") for storage_align. However, the input `axis` is " << axis_ << ", which is out of the expected range."; From d9806e6b4c697459e125a0b5e632e1e9c41cb6f8 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 10 Aug 2021 21:22:06 -0400 Subject: [PATCH 3/5] Update src/tir/schedule/primitive/block_annotate.cc Co-authored-by: Junru Shao --- src/tir/schedule/primitive/block_annotate.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index f728267a32b1..cb15eeb715ba 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -70,7 +70,7 @@ class StorageAlignAxisOutOfRangeError : public ScheduleError { * \return The defining site of the buffer and whether the buffer is allocated (otherwise the * buffer is from match_buffer). */ -std::pair GetBufferDefiningSite(const StmtSRef& block_sref, const Buffer& buffer) { +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, const Buffer& buffer) { // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or // match_buffers. const StmtSRefNode* defining_site_sref = block_sref.get(); From f781f6b63e02c0ba5715bf28a9e9d5f1dbbf4957 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 10 Aug 2021 21:25:48 -0400 Subject: [PATCH 4/5] Address comments --- include/tvm/tir/schedule/schedule.h | 5 ++- python/tvm/tir/schedule/schedule.py | 4 +- src/tir/schedule/analysis.h | 11 ------ src/tir/schedule/analysis/analysis.cc | 9 ----- src/tir/schedule/primitive.h | 10 ++++- src/tir/schedule/primitive/block_annotate.cc | 17 ++++---- src/tir/schedule/transform.cc | 35 +++++++++++++++++ src/tir/schedule/transform.h | 41 ++++++++++++++++++++ src/tir/transforms/compact_buffer_region.cc | 12 +++--- 9 files changed, 108 insertions(+), 36 deletions(-) create mode 100644 src/tir/schedule/transform.cc create mode 100644 src/tir/schedule/transform.h diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 25a2022281d2..e5d2c440e57b 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -267,7 +267,10 @@ class ScheduleNode : public runtime::Object { /******** Schedule: Block annotation ********/ /*! * \brief Set alignment requirement for specific dimension such that - * stride[axis] == k * factor + offset for some k. + * stride[axis] == k * factor + offset for some k. This is useful to set memory layout for + * more friendly memory access pattern. For example, we can set alignment to be factor=2, + * offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared + * memory. * \param block_rv The producer block of the buffer * \param buffer_index The index of the buffer in block's write region * \param axis The dimension to be specified for alignment diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 6e381443b52d..e8415d2bd522 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -716,7 +716,9 @@ def storage_align( # pylint: disable=too-many-arguments self, block: BlockRV, buffer_index: int, axis: int, factor: int, offset: int ) -> None: """Set alignment requirement for specific dimension such that - stride[axis] == k * factor + offset for some k. + stride[axis] == k * factor + offset for some k. This is useful to set memory layout for more + friendly memory access pattern. For example, we can set alignment to be factor=2, offset=1 + to avoid bank conflict for thread access on higher dimension in GPU shared memory. Parameters ---------- diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index cc7df0dbefbb..370aa01a33c0 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -237,17 +237,6 @@ std::vector> GetReducerGetters(); bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs); -/******** Annotation ********/ - -/*! - * \brief Create a new block with the given annotation added - * \param block The block with original annotation - * \param attr_key The annotation key to be added - * \param attr_value The annotation value to be added - * \return A new block with the given annotation as its last annotation - */ -Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value); - } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index f5e9a675e14b..8d1913fdee86 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -566,15 +566,6 @@ Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) { return block->writes[n]->buffer; } -/******** Annotation ********/ -Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) { - Map annotations = block->annotations; - annotations.Set(attr_key, attr_value); - ObjectPtr new_block = make_object(*block); - new_block->annotations = std::move(annotations); - return Block(new_block); -} - /******** Pattern Matcher ********/ /*! diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 01ee59038430..4b9c76947bb1 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -107,7 +107,10 @@ TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int fact /******** Schedule: Block annotation ********/ /*! * \brief Set alignment requirement for specific dimension such that - * stride[axis] == k * factor + offset for some k + * stride[axis] == k * factor + offset for some k. This is useful to set memory layout for + * more friendly memory access pattern. For example, we can set alignment to be factor=2, + * offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared + * memory. * \param block_sref The producer block of the buffer * \param buffer_index The index of the buffer in block's write region * \param axis The dimension to be specified for alignment @@ -116,6 +119,11 @@ TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int fact */ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset); + +/******** Annotation types for StorageAlign ********/ +using StorageAlignTuple = Array; // (buffer_idx, axis, factor, offset) +using StorageAlignAnnotation = Array; // unordered array of StorageAlignTuple + /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index cb15eeb715ba..cff527bf975f 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include "../transform.h" #include "../utils.h" namespace tvm { @@ -70,7 +71,8 @@ class StorageAlignAxisOutOfRangeError : public ScheduleError { * \return The defining site of the buffer and whether the buffer is allocated (otherwise the * buffer is from match_buffer). */ -std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, const Buffer& buffer) { +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, + const Buffer& buffer) { // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or // match_buffers. const StmtSRefNode* defining_site_sref = block_sref.get(); @@ -97,7 +99,7 @@ std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_ } // If we cannot find the defining site block, it means that the buffer must be in the function's // buffer_map, which isn't an intermediate buffer. - return {StmtSRef(), false}; + return {NullOpt, false}; } class NonAllocatedBufferError : public ScheduleError { @@ -119,10 +121,10 @@ class NonAllocatedBufferError : public ScheduleError { static void CheckBufferAllocated(const IRModule& mod, const StmtSRef& block_sref, const Buffer& buffer) { - StmtSRef defining_site_sref; + Optional defining_site_sref; bool is_alloc; std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, buffer); - if (!defining_site_sref.defined() || !is_alloc) { + if (!defining_site_sref || !is_alloc) { throw NonAllocatedBufferError(mod, buffer); } } @@ -194,7 +196,7 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { if (!IsValidAnnotation(block, (*it).second)) { throw StorageAlignInvalidAnnotationError(mod, block); } - return Downcast>>>((*it).second); + return Downcast>((*it).second); } // Create new annotation value @@ -252,8 +254,9 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer); // Step 1: Get existing or create new annotation value. - auto storage_align_annotation = StorageAlignInvalidAnnotationError::CheckAndGetAnnotation( - self->mod, GetRef(block_ptr)); + Array storage_align_annotation = + StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod, + GetRef(block_ptr)); // Step 2: Update the annotation value Array> buffer_storage_align = storage_align_annotation[buffer_index]; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc new file mode 100644 index 000000000000..c296d7f8520b --- /dev/null +++ b/src/tir/schedule/transform.cc @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./transform.h" + +namespace tvm { +namespace tir { + +/******** Annotation ********/ +Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) { + Map annotations = block->annotations; + annotations.Set(attr_key, attr_value); + ObjectPtr new_block = make_object(*block); + new_block->annotations = std::move(annotations); + return Block(new_block); +} + +} // namespace tir +} // namespace tvm \ No newline at end of file diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h new file mode 100644 index 000000000000..083293e42e03 --- /dev/null +++ b/src/tir/schedule/transform.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_ +#define TVM_TIR_SCHEDULE_TRANSFORM_H_ + +#include + +namespace tvm { +namespace tir { + +/******** Annotation ********/ + +/*! + * \brief Create a new block with the given annotation added + * \param block The block with original annotation + * \param attr_key The annotation key to be added + * \param attr_value The annotation value to be added + * \return A new block with the given annotation as its last annotation + */ +Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value); + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_ \ No newline at end of file diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index a549061b5f1b..9fe14476ed8b 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -306,7 +306,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { /*! \brief Collect storage alignment information from block annotations. */ class StorageAlignCollector : public StmtVisitor { public: - static std::unordered_map>, ObjectPtrHash, ObjectPtrEqual> Collect( + static std::unordered_map Collect( const PrimFunc& f) { StorageAlignCollector collector; collector(f->body); @@ -317,7 +317,7 @@ class StorageAlignCollector : public StmtVisitor { void VisitStmt_(const BlockNode* op) final { auto it = op->annotations.find(attr::buffer_dim_align); if (it != op->annotations.end()) { - const auto& storage_align = Downcast>>>((*it).second); + const auto& storage_align = Downcast>((*it).second); ICHECK(storage_align.size() == op->writes.size()); for (size_t i = 0; i < storage_align.size(); ++i) { CHECK(!storage_align_.count(op->writes[i]->buffer)) @@ -329,7 +329,7 @@ class StorageAlignCollector : public StmtVisitor { } /*! \brief The map from Buffer to its storage alignment information. */ - std::unordered_map>, ObjectPtrHash, ObjectPtrEqual> storage_align_; + std::unordered_map storage_align_; }; /*! \brief Reallocate the buffers with minimal region. */ @@ -338,7 +338,7 @@ class BufferCompactor : public StmtExprMutator { static Stmt Compact( const PrimFunc& f, const std::unordered_map& regions, - const std::unordered_map>, ObjectPtrHash, ObjectPtrEqual>& + const std::unordered_map& storage_align) { std::unordered_map buffer_info; @@ -529,8 +529,8 @@ PrimFunc CompactBufferAllocation(PrimFunc f) { PrimFuncNode* fptr = f.CopyOnWrite(); std::unordered_map region = BufferAccessRegionCollector::Collect(f); - std::unordered_map>, ObjectPtrHash, ObjectPtrEqual> storage_align = - StorageAlignCollector::Collect(f); + std::unordered_map + storage_align = StorageAlignCollector::Collect(f); fptr->body = BufferCompactor::Compact(f, region, storage_align); return f; } else { From fd3f0ded2aac6f622de130fe72c973312c5d690f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 12 Aug 2021 15:54:46 -0400 Subject: [PATCH 5/5] Address comments --- src/tir/schedule/primitive/block_annotate.cc | 61 ++++++++----------- src/tir/schedule/transform.cc | 2 +- src/tir/schedule/transform.h | 2 +- src/tir/transforms/compact_buffer_region.cc | 21 +++---- .../test_tir_schedule_storage_align.py | 2 +- ...est_tir_transform_compact_buffer_region.py | 4 +- 6 files changed, 41 insertions(+), 51 deletions(-) diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index cff527bf975f..937bc7c3802f 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -176,32 +176,30 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { String FastErrorString() const final { return "ScheduleError: The block annotation for storage align is expected to be an array of " - "3-integer-tuples (axis, factor, offset)."; + "4-integer-tuples (buffer_index, axis, factor, offset)."; } String DetailRenderTemplate() const final { std::ostringstream os; - os << "The block annotation for storage align is expected to be an array of 3-integer-tuples " - "(axis, factor, offset). However, the block annotation with key " + os << "The block annotation for storage align is expected to be an array of 4-integer-tuples " + "(buffer_index, axis, factor, offset). However, the block annotation with key " << attr::buffer_dim_align << " of the block {0} is " << block_->annotations.at(attr::buffer_dim_align) << ", which is unexpected."; return os.str(); } - static Array>> CheckAndGetAnnotation(const IRModule& mod, - const Block& block) { + static StorageAlignAnnotation CheckAndGetAnnotation(const IRModule& mod, const Block& block) { // Get existing annotation value. auto it = block->annotations.find(attr::buffer_dim_align); if (it != block->annotations.end()) { if (!IsValidAnnotation(block, (*it).second)) { throw StorageAlignInvalidAnnotationError(mod, block); } - return Downcast>((*it).second); + return Downcast((*it).second); } // Create new annotation value - Array>> storage_align_annotation; - storage_align_annotation.resize(block->writes.size()); + StorageAlignAnnotation storage_align_annotation; return storage_align_annotation; } @@ -213,29 +211,20 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { if (!anno_value->IsInstance()) { return false; } - const auto& buffer_annotations = Downcast>(anno_value); - if (buffer_annotations.size() != block->writes.size()) { - return false; - } - for (const ObjectRef& buffer_annotation : buffer_annotations) { - if (!buffer_annotation->IsInstance()) { + auto storage_align_annotations = Downcast>(anno_value); + for (const ObjectRef& storage_align_annotation : storage_align_annotations) { + if (!storage_align_annotation->IsInstance()) { return false; } - const auto& dim_annotations = Downcast>(buffer_annotation); - for (const ObjectRef& dim_annotation : dim_annotations) { - if (!dim_annotation->IsInstance()) { - return false; - } - const auto& dim_anno = Downcast>(dim_annotation); - // Check if the annotations are consist of 3-tuples. - if (dim_anno.size() != 3) { + auto storage_align_tuple = Downcast>(storage_align_annotation); + // Check if the annotation is a 4-tuple. + if (storage_align_tuple.size() != 4) { + return false; + } + for (const ObjectRef& tuple_element : storage_align_tuple) { + if (!tuple_element->IsInstance()) { return false; } - for (const ObjectRef& dim_anno_element : dim_anno) { - if (!dim_anno_element->IsInstance()) { - return false; - } - } } } return true; @@ -254,25 +243,27 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer); // Step 1: Get existing or create new annotation value. - Array storage_align_annotation = + StorageAlignAnnotation storage_align_annotation = StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod, GetRef(block_ptr)); // Step 2: Update the annotation value - Array> buffer_storage_align = storage_align_annotation[buffer_index]; + // Array> buffer_storage_align = storage_align_annotation[buffer_index]; bool found = false; - for (size_t j = 0; j < buffer_storage_align.size(); ++j) { - ICHECK(buffer_storage_align[j].size() == 3); - if (buffer_storage_align[j][0] == axis) { - buffer_storage_align.Set(j, {Integer(axis), Integer(factor), Integer(offset)}); + StorageAlignTuple new_storage_align_tuple{Integer(buffer_index), Integer(axis), Integer(factor), + Integer(offset)}; + for (size_t j = 0; j < storage_align_annotation.size(); ++j) { + const auto& storage_align_tuple = storage_align_annotation[j]; + ICHECK(storage_align_tuple.size() == 4); + if (storage_align_tuple[0] == buffer_index && storage_align_tuple[1] == axis) { + storage_align_annotation.Set(j, std::move(new_storage_align_tuple)); found = true; break; } } if (!found) { - buffer_storage_align.push_back({Integer(axis), Integer(factor), Integer(offset)}); + storage_align_annotation.push_back(std::move(new_storage_align_tuple)); } - storage_align_annotation.Set(buffer_index, std::move(buffer_storage_align)); // Step 3: Replace the block with the new annotation Block new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation); diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index c296d7f8520b..f27e0f6d62eb 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -32,4 +32,4 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec } } // namespace tir -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h index 083293e42e03..53483829a303 100644 --- a/src/tir/schedule/transform.h +++ b/src/tir/schedule/transform.h @@ -38,4 +38,4 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec } // namespace tir } // namespace tvm -#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_ \ No newline at end of file +#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_ diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 9fe14476ed8b..961ea1721fa1 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -317,12 +317,11 @@ class StorageAlignCollector : public StmtVisitor { void VisitStmt_(const BlockNode* op) final { auto it = op->annotations.find(attr::buffer_dim_align); if (it != op->annotations.end()) { - const auto& storage_align = Downcast>((*it).second); - ICHECK(storage_align.size() == op->writes.size()); - for (size_t i = 0; i < storage_align.size(); ++i) { - CHECK(!storage_align_.count(op->writes[i]->buffer)) - << "ValueError: Conflicting storage_align for buffer " << op->writes[i]->buffer->name; - storage_align_.emplace(op->writes[i]->buffer, storage_align[i]); + auto storage_align_annotation = Downcast((*it).second); + for (const auto& storage_align_tuple : storage_align_annotation) { + int buffer_index = storage_align_tuple[0]->value; + const Buffer& buffer = op->writes[buffer_index]->buffer; + storage_align_[buffer].push_back(storage_align_tuple); } } StmtVisitor::VisitStmt_(op); @@ -349,11 +348,11 @@ class BufferCompactor : public StmtExprMutator { auto it = storage_align.find(buffer); if (it != storage_align.end()) { std::vector dim_aligns(buffer->shape.size()); - for (const Array& dim_align : (*it).second) { - ICHECK(dim_align.size() == 3); - int dim = dim_align[0]->value; - int factor = dim_align[1]->value; - int offset = dim_align[2]->value; + for (const StorageAlignTuple& dim_align : (*it).second) { + ICHECK(dim_align.size() == 4); + int dim = dim_align[1]->value; + int factor = dim_align[2]->value; + int offset = dim_align[3]->value; dim_aligns.at(dim) = {factor, offset}; } buffer_alloc_info.dim_aligns = std::move(dim_aligns); diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py index 33b430af43dd..a0a069347f95 100644 --- a/tests/python/unittest/test_tir_schedule_storage_align.py +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -66,7 +66,7 @@ def element_wise_storage_align(a: ty.handle, c: ty.handle) -> None: tir.bind(vj, ax1) tir.reads([A[vi, vj]]) tir.writes([B[vi, vj]]) - tir.block_attr({"buffer_dim_align":[[[0, 128, 127]]]}) + tir.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]}) B[vi, vj] = (A[vi, vj]*tir.float32(2)) for i1 in tir.serial(0, 128): with tir.block([128, 128], "C") as [vi_1, vj_1]: diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 4ebe4af88434..15da022e67d6 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -352,7 +352,7 @@ def storage_align_func(a: ty.handle, c: ty.handle) -> None: with tir.block([]) as []: tir.reads(A[i, j]) tir.writes(B[i, j]) - tir.block_attr({"buffer_dim_align": [[[0, 16, 15]]]}) + tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): with tir.block([]) as []: @@ -374,7 +374,7 @@ def compacted_storage_align_func(a: ty.handle, c: ty.handle) -> None: with tir.block() as []: tir.reads(A[i, j]) tir.writes(B[0, j]) - tir.block_attr({"buffer_dim_align": [[[0, 16, 15]]]}) + tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) B[0, j] = A[i, j] + 1.0 for j in range(0, 16): with tir.block() as []: