diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 22febfdfedec..1ae16fbae276 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -466,6 +466,11 @@ class ScheduleNode : public runtime::Object { */ virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) = 0; + /******** Schedule: Data movement ********/ + virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) = 0; + virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) = 0; /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 7a7ad2acedd7..d7074a7805be 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1609,6 +1609,37 @@ constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layo */ constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init"; +/*! + * \brief Mark that the block needs to add a predicate for block var bounds during lowering + */ +constexpr const char* require_block_var_bound_predicate = "require_bound_predicate"; + +/*! \brief Mark that tensor core is enabled in the PrimExpr */ +constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled"; + +/*! + * \brief Mark a block as generated by cache_read or cache_write. + * 0 means cache_read; 1 means cache_write. + * \sa meta_schedule_cache_type_read + * \sa meta_schedule_cache_type_write + */ +constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type"; + +/*! \sa meta_schedule_cache_type */ +constexpr const int meta_schedule_cache_type_read = 0; + +/*! \sa meta_schedule_cache_type */ +constexpr const int meta_schedule_cache_type_write = 1; + +/*! \brief Mark auto copy for memhammer */ +constexpr const char* auto_copy = "auto_copy"; + +/*! \brief Mark local stage constraint on data copy */ +constexpr const char* local_stage = "local_stage"; + +/*! \brief Mark vectorization length constraint on block */ +constexpr const char* vector_bytes = "vector_bytes"; + /*! * \brief Mark that a block is executed by a warp. This implies the extend of threadIdx.x is * warp size. diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 85b381a52950..0aaa8b3e8aec 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -647,6 +647,12 @@ TVM_DLL Pass BindParams(const Array& constants); */ TVM_DLL Pass ExtractPrimFuncConstants(); +/*! + * \brief Automatically do memory optimizations for auto copy blocks + * \return The pass. + */ +TVM_DLL Pass LowerAutoCopy(); + /*! * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) * \return The pass.
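Taken together, the header changes above add two data-movement primitives (ReadAt / WriteAt), a set of block annotations (auto_copy, local_stage, vector_bytes), and a LowerAutoCopy pass that rewrites the annotated copy blocks during lowering. A minimal sketch of how these pieces are meant to compose through the Python bindings added below — the workload, loop choices, buffer indices, and annotation values are illustrative assumptions, not taken from this patch:

    from tvm import te, tir

    # A toy reduction workload; any PrimFunc with a stage-pipeline scope would do.
    A = te.placeholder((128, 128), name="A")
    B = te.placeholder((128, 128), name="B")
    k = te.reduce_axis((0, 128), name="k")
    C = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
    sch = tir.Schedule(te.create_prim_func([A, B, C]))

    block = sch.get_block("C")
    li, lj, lk = sch.get_loops(block)
    ko, _ki = sch.split(lk, factors=[None, 32])

    # Stage the first read buffer (A) into shared memory under loop `ko`; the result
    # is the generated copy block, which carries the "auto_copy" annotation.
    a_shared = sch.read_at(ko, block, 0, "shared")
    # Optional constraint consumed by LowerAutoCopy (attr name matches stmt.h above).
    sch.annotate(a_shared, "vector_bytes", 16)

    # Stage the output (write buffer 0, i.e. C) through shared memory under loop `li`.
    sch.write_at(li, block, 0, "shared")

    # The driver change below runs tir.transform.LowerAutoCopy() right after
    # CompactBufferAllocation(), rewriting every "auto_copy" block into coalesced
    # (and, where the constraints allow, vectorized) copies.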
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 896e2fc48e72..7e8ac2164102 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1666,6 +1666,30 @@ def after_reindex( self, block, buffer_index, buffer_index_type_enum ) + ########## Schedule: Data movement ########## + + def read_at( + self, + loop: LoopRV, + block: BlockRV, + read_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleReadAt( # type: ignore # pylint: disable=no-member + self, loop, block, read_buffer_index, storage_scope + ) + + def write_at( + self, + loop: LoopRV, + block: BlockRV, + write_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleWriteAt( # type: ignore # pylint: disable=no-member + self, loop, block, write_buffer_index, storage_scope + ) + ########## Schedule: Compute location ########## @type_checked diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index bc3ec5b2ad74..a18d698e5426 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -926,6 +926,17 @@ def ExtractPrimFuncConstants(): return _ffi_api.ExtractPrimFuncConstants() # type: ignore +def LowerAutoCopy(): + """Automatically do memory optimizations for auto copy blocks + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerAutoCopy() # type: ignore + + def RenormalizeSplitPattern(): """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index da1bbc296a49..1e126ebf7bc0 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -214,6 +214,7 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::LowerAutoCopy()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index f0459785f352..3b0ef07f3758 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -309,6 +309,7 @@ Sequential PassListForPerStoreFeature() { tir::transform::ConvertBlocksToOpaque(), tir::transform::UnifyThreadBinding(), tir::transform::CompactBufferAllocation(), + tir::transform::LowerAutoCopy(), tir::transform::LowerMatchBuffer(), tir::transform::Simplify(), }); diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 6f9b46a0f734..0013106b09e8 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -164,6 +164,7 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::LowerAutoCopy()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); 
pass_list.push_back(tir::transform::LowerOpaqueBlock()); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 5a9dab4854bd..6f6dbc138589 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -629,6 +629,30 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, return CreateRV(result); } +/******** Schedule: Data movement ********/ + +BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("read-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("write-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 82ac9f913374..290b6a4456e4 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -126,6 +126,11 @@ class ConcreteScheduleNode : public ScheduleNode { int cse_thresh) override; BlockRV ReIndex(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type) override; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) override; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) override; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 563864229a26..89cdf68a458a 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -339,6 +339,15 @@ TVM_DLL Array CacheIndex(ScheduleState self, const StmtSRef& block_sre */ TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index, BufferIndexType buffer_index_type); + +/******** Schedule: Data movement ********/ + +TVM_DLL StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope); + +TVM_DLL StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope); + /******** Schedule: Compute location ********/ /*! 
* \brief Move a producer block under the specific loop, and regenerate the diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/tir/schedule/primitive/read_write_at.cc new file mode 100644 index 000000000000..8b7d78f6699e --- /dev/null +++ b/src/tir/schedule/primitive/read_write_at.cc @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +bool HasBuffer(const Array& buffer_regions, const Buffer& buffer) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + return true; + } + } + return false; +} + +void RelaxBufferRegions(const Array& buffer_regions, + const Buffer& buffer, // + const Map& var_dom, // + const Map& bindings, // + std::vector* relaxed_regions) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + Array relaxed_region = + arith::EvalSet(Substitute(buffer_region->region, bindings), var_dom); + relaxed_regions->push_back({relaxed_region.begin(), relaxed_region.end()}); + } + } +} + +class ScopeReplacer : public StmtMutator { + public: + static Block Replace(const BlockNode* scope_block, const Buffer& dst, const ForNode* old_loop, + const ForNode* new_loop) { + ObjectPtr new_scope_block = make_object(*scope_block); + new_scope_block->body = ScopeReplacer(old_loop, new_loop)(std::move(new_scope_block->body)); + new_scope_block->alloc_buffers.push_back(dst); + return Block(new_scope_block); + } + + private: + explicit ScopeReplacer(const ForNode* old_loop, const ForNode* new_loop) + : old_loop_(old_loop), new_loop_(new_loop), found_(false) {} + + Stmt VisitStmt(const Stmt& stmt) final { return found_ ? 
stmt : StmtMutator::VisitStmt(stmt); } + Stmt VisitStmt_(const BlockNode* block) final { return GetRef(block); } + Stmt VisitStmt_(const ForNode* loop) final { + if (loop == old_loop_) { + found_ = true; + return GetRef(new_loop_); + } + return StmtMutator::VisitStmt_(loop); + } + + const ForNode* old_loop_; + const ForNode* new_loop_; + bool found_; +}; + +class ReadWriteAtBufferReplacer : public StmtExprMutator { + public: + explicit ReadWriteAtBufferReplacer(const Buffer& src, const Buffer& dst, + Map* block_sref_reuse) + : src_(src), dst_(dst), block_sref_reuse_(block_sref_reuse) {} + + private: + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if (store->buffer.same_as(src_)) { + ObjectPtr new_store = make_object(*store.get()); + new_store->buffer = dst_; + return BufferStore(new_store); + } + return store; + } + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (load->buffer.same_as(src_)) { + ObjectPtr new_load = make_object(*load.get()); + new_load->buffer = dst_; + return BufferLoad(new_load); + } + return load; + } + + Stmt VisitStmt_(const BlockNode* _block) final { + Block old_block = GetRef(_block); + Block block = Downcast(StmtExprMutator::VisitStmt_(_block)); + ObjectPtr new_block = make_object(*block.get()); + new_block->reads = ReplaceBuffer(new_block->reads, src_, dst_); + new_block->writes = ReplaceBuffer(new_block->writes, src_, dst_); + block_sref_reuse_->Set(old_block, Block(new_block)); + return Block(new_block); + } + + const Buffer& src_; + const Buffer& dst_; + Map* block_sref_reuse_; +}; + +struct ReadWriteAtImpl { + template + static StmtSRef Main(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope, + Map annotations) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + Buffer src = GetNthAccessBuffer(self, GetRef(block), buffer_index, + is_read ? 
BufferIndexType::kRead : BufferIndexType::kWrite); + Buffer dst = WithScope(src, storage_scope); + ReadWriteAtImpl impl(self, loop_sref, src, dst, annotations); + std::pair new_loop_block = + impl.MakeLoopAndBlock(src->name + "_" + storage_scope); + StmtSRef result_block_sref = + impl.ReplaceScopeBlock(new_loop_block.first.get(), new_loop_block.second->block.get()); + impl.UpdateBlockInfo(result_block_sref, !new_loop_block.second->iter_values.empty()); + return result_block_sref; + } + + private: + static Map GetLoopDomain(const StmtSRefNode* loop_sref) { + Map result; + for (const ForNode* loop; (loop = loop_sref->StmtAs()) != nullptr; + loop_sref = loop_sref->parent) { + result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + return result; + } + + StmtSRef ReplaceScopeBlock(const ForNode* new_loop, const BlockNode* new_block) { + StmtSRef scope_root_sref = GetScopeRoot(self_, loop_sref_, + /*require_stage_pipeline=*/true); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_root_sref); + Block new_scope_block = ScopeReplacer::Replace(scope_block, dst_, loop_, new_loop); + block_sref_reuse_.Set(GetRef(scope_block), new_scope_block); + self_->Replace(scope_root_sref, new_scope_block, block_sref_reuse_); + return self_->stmt2ref.at(new_block); + } + + void UpdateBlockInfo(const StmtSRef& new_block_sref, bool affine_binding) { + BlockInfo& block_info = self_->block_info[new_block_sref]; + block_info.affine_binding = affine_binding; + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + } + + template + std::pair MakeLoopAndBlock(const String& new_block_name_hint) { + Array subtrees = AsArray(loop_->body); + int n_subtrees = subtrees.size(); + runtime::StorageScope scope = runtime::StorageScope::Create(dst_.scope()); + std::vector relaxed_regions; + std::vector r_pos; + std::vector w_pos; + relaxed_regions.reserve(n_subtrees); + r_pos.reserve(n_subtrees); + w_pos.reserve(n_subtrees); + // Step 1. Iterate over all subtrees + for (int i = 0; i < n_subtrees; ++i) { + bool r_visited = false; + bool w_visited = false; + auto f_visit = [this, &relaxed_regions, &r_visited, &w_visited, + &scope](const ObjectRef& obj) -> bool { + const BlockRealizeNode* realize = obj.as(); + if (realize == nullptr) { + return true; + } + const BlockNode* block = realize->block.get(); + bool has_r = HasBuffer(block->reads, src_); + bool has_w = HasBuffer(block->writes, src_); + r_visited = r_visited || has_r; + w_visited = w_visited || has_w; + if (is_read ? has_r : has_w) { + RelaxBufferRegions( + /*buffer_regions=*/is_read ? block->reads : block->writes, + /*buffer=*/src_, + /*var_dom=*/ + arith::AsIntSet(LoopDomainOfSRefTreePath( + /*low_inclusive=*/GetRef(self_->stmt2ref.at(block)->parent), + /*high_exclusive=*/loop_sref_, + /*extra_relax_scope=*/scope)), + /*bindings=*/GetBindings(GetRef(realize)), + /*relaxed_regions=*/&relaxed_regions); + } + return false; + }; + PreOrderVisit(subtrees[i], f_visit); + if (r_visited) { + r_pos.push_back(i); + } + if (w_visited) { + w_pos.push_back(i); + } + } + // Step 2. Calculate `insert_pos` and [st, ed) for buffer replacement + int insert_pos = -1, st = -1, ed = -1; + if (is_read) { + ICHECK(!r_pos.empty()); + // No write after the first read + ICHECK(w_pos.empty() || w_pos.back() < r_pos.front()); + // Can be inserted at [0, r_pos.front()], i.e. 
before the first read + insert_pos = r_pos.front(); + // Buffer reads in [insert_pos, +oo) is rewritten + st = insert_pos; + ed = n_subtrees; + } else { + ICHECK(!w_pos.empty()); + // No read after the last write + ICHECK(r_pos.empty() || r_pos.back() <= w_pos.back()); + // Can be inserted into (w_pos.back(), +oo), i.e. after the last write + insert_pos = w_pos.back() + 1; + st = 0; + ed = insert_pos; + } + // Step 3. Calculate `domain`, the domain of buffer access + NDIntSet relaxed = support::NDIntSetUnion(relaxed_regions); + int ndim = relaxed.size(); + Array domain; + domain.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + const arith::IntSet& int_set = relaxed[i]; + PrimExpr min = analyzer_->Simplify(int_set.min()); + PrimExpr extent = analyzer_->Simplify(int_set.max() + 1 - min); + domain.push_back(Range::FromMinExtent(min, extent)); + } + // Step 4. Insert the auto copy block and replace buffers + ReadWriteAtBufferReplacer replacer(src_, dst_, &block_sref_reuse_); + for (int i = st; i < ed; ++i) { + Stmt stmt = subtrees[i]; + subtrees.Set(i, Stmt(nullptr)); + subtrees.Set(i, replacer(std::move(stmt))); + } + BlockRealize realize = + is_read + ? MakeBlock(src_, dst_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain) + : MakeBlock(dst_, src_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain); + subtrees.insert(subtrees.begin() + insert_pos, realize); + ObjectPtr new_loop = make_object(*loop_); + new_loop->body = SeqStmt(std::move(subtrees)); + return {For(new_loop), realize}; + } + + BlockRealize MakeBlock(const Buffer& copy_from, const Buffer& copy_to, const String& name_hint, + const Map& loop_domain, Array domain) const { + int n = domain.size(); + std::vector loop_vars; + loop_vars.reserve(n); + for (int i = 0; i < n; ++i) { + loop_vars.push_back(Var("ax" + std::to_string(i))); + } + Map bindings; + Array iter_vars; + Array iter_values; + Array indices; + iter_vars.reserve(n); + iter_values.reserve(n); + indices.reserve(n); + for (int i = 0; i < n; ++i) { + auto f_substitute = [&loop_domain, &bindings, &iter_vars, + &iter_values](const Var& var) -> Optional { + auto it = bindings.find(var); + if (it != bindings.end()) { + return (*it).second; + } + Range range = loop_domain.at(var); + ObjectPtr v = make_object(*var.get()); + v->name_hint = "v" + std::to_string(iter_vars.size()); + bindings.Set(var, Var(v)); + iter_values.push_back(var); + iter_vars.push_back(IterVar(range, Var(v), IterVarType::kDataPar)); + return Var(v); + }; + ObjectPtr dom = make_object(*domain[i].get()); + dom->min = Substitute(std::move(dom->min), f_substitute); + dom->extent = Substitute(std::move(dom->extent), f_substitute); + domain.Set(i, Range(dom)); + } + for (int i = 0; i < n; ++i) { + indices.push_back(domain[i]->min + loop_vars[i]); + } + Stmt stmt = BufferStore(copy_to, /*value=*/BufferLoad(copy_from, indices), /*indices=*/indices); + for (int i = n - 1; i >= 0; --i) { + stmt = For(loop_vars[i], Integer(0), domain[i]->extent, ForKind::kSerial, stmt); + } + return BlockRealize( + /*values=*/iter_values, + /*predicate=*/const_true(), + Block(/*iter_vars=*/iter_vars, + /*reads=*/{BufferRegion(copy_from, domain)}, + /*writes=*/{BufferRegion(copy_to, domain)}, + /*name_hint=*/name_hint, // + /*body=*/std::move(stmt), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations_)); + } + + explicit ReadWriteAtImpl(ScheduleState self, const StmtSRef& loop_sref, const Buffer& src, + const Buffer& dst, Map annotations) + : self_(self), + 
loop_sref_(loop_sref), + loop_(nullptr), + src_(src), + dst_(dst), + annotations_(annotations), + block_sref_reuse_(), + analyzer_(std::make_unique()) { + loop_ = TVM_SREF_TO_FOR(loop_sref); + } + + ScheduleState self_; + const StmtSRef& loop_sref_; + const ForNode* loop_; + const Buffer& src_; + const Buffer& dst_; + Map annotations_; + Map block_sref_reuse_; + std::unique_ptr analyzer_; +}; + +StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope) { + return ReadWriteAtImpl::Main(self, loop_sref, block_sref, read_buffer_index, storage_scope, + {{tir::attr::auto_copy, Integer(1)}}); +} + +StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope) { + return ReadWriteAtImpl::Main(self, loop_sref, block_sref, write_buffer_index, + storage_scope, {{tir::attr::auto_copy, Integer(1)}}); +} + +/******** Instruction Registration ********/ + +struct ReadAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReadAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope); + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer read_buffer_index, String storage_scope) { + return sch->ReadAt(loop, block, read_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer read_buffer_index, String storage_scope) { + PythonAPICall py("read_at"); + py.Input("loop", loop); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct WriteAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "WriteAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer write_buffer_index, String storage_scope) { + return sch->WriteAt(loop, block, write_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer write_buffer_index, String storage_scope) { + PythonAPICall py("write_at"); + py.Input("loop", loop); + py.Input("block", block); + py.Input("write_buffer_index", write_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ReadAtTraits); +TVM_REGISTER_INST_KIND_TRAITS(WriteAtTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index cb8b5a1d7787..e43148a7eb10 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -192,6 +192,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReIndex") int buffer_index_type) { return self->ReIndex(block_rv, buffer_index, static_cast(buffer_index_type)); }); 
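A note on the two traits registered above: UnpackedAsPython only controls how the instructions are printed when a trace is rendered, so a trace containing them would read roughly as follows (the variable names and indices here are made up for illustration):

    b0 = sch.get_block(name="C", func_name="main")
    l1, l2, l3 = sch.get_loops(block=b0)
    b4 = sch.read_at(loop=l3, block=b0, read_buffer_index=0, storage_scope="shared")
    b5 = sch.write_at(loop=l2, block=b0, write_buffer_index=0, storage_scope="shared")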
+/******** (FFI) Data movement ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method(&ScheduleNode::ReadAt); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt") + .set_body_method(&ScheduleNode::WriteAt); /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt") .set_body_method(&ScheduleNode::ComputeAt); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index a5cb66a0cb44..7c8c800ca748 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -384,6 +384,34 @@ BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index, return result; } +/******** Schedule: Data movement ********/ + +BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::ReadAt(loop_rv, block_rv, read_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("ReadAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + +BlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::WriteAt(loop_rv, block_rv, write_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("WriteAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + /******** Schedule: Compute location ********/ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 1fcba9806380..7854adad39cb 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -86,6 +86,11 @@ class TracedScheduleNode : public ConcreteScheduleNode { BufferIndexType buffer_index_type) final; Array CacheIndex(const BlockRV& block_rv, const String& storage_scope, int cse_thresh) final; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) final; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) final; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops, int index = -1) final; diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc new file mode 100644 index 000000000000..5ca20f57aa78 --- /dev/null +++ b/src/tir/transforms/memhammer_coalesce.cc @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../../runtime/thread_storage_scope.h" +#include "./memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Fuse consecutive loops + * \param body the outer-most loop + * \return the fused loop + */ +Stmt FuseNestLoops(Stmt body) { + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + std::string suffix; + int n = loops.size(); + for (int i = 1; i < n; i++) { + suffix += "_" + loops[i]->loop_var->name_hint; + } + suffix += "_fused"; + Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); + Map subst_map; + PrimExpr tot = fused_var; + for (int i = n - 1; i >= 0; i--) { + subst_map.Set(loops[i]->loop_var, floormod(tot, loops[i]->extent)); + tot = floordiv(tot, loops[i]->extent); + } + auto f_substitute = [&](const Var& v) -> Optional { + return subst_map.Get(v).value_or(v); + }; + PrimExpr fused_extent = 1; + for (int i = 0; i < n; i++) { + fused_extent *= loops[i]->extent; + } + return For(fused_var, 0, fused_extent, ForKind::kSerial, + Substitute(std::move(body), f_substitute)); +} + +/*! + * \brief a combination of split, bind, vectorize, + * a helper function to perform coalesced load/store + * \param stmt the stmt to do transformation + * \param constraints The constraints, including thread extents, vector bytes, and data bits. 
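+ * For example (illustrative numbers, not taken from a test): a fused copy loop of
+ * extent 2048 with threadIdx.x = 128 available, vector_bytes = 16 and 32-bit data
+ * gives vector_len = 4 and split factors [4, 128, 4]: a serial outer loop, a loop
+ * bound to threadIdx.x, and an innermost vectorized loop. A guard predicate is only
+ * generated when the original extent does not divide evenly.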
+ * \return The stmt after transformation + */ +Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { + const ForNode* loop = TVM_TYPE_AS(stmt, ForNode); + int loop_extent = Downcast(loop->extent)->value; + int vector_bytes = constraints.vector_bytes; + int data_bits = constraints.data_bits; + int vector_len = std::max(1, vector_bytes * 8 / data_bits); + int tot_threads = 1; + // generate thread binding loops + std::vector factors{-1}; + std::vector thread_axis; + if (Optional o_t = constraints.thread_extent.Get("threadIdx.z")) { + int t = o_t.value()->value; + tot_threads *= t; + factors.push_back(t); + thread_axis.push_back("threadIdx.z"); + } + if (Optional o_t = constraints.thread_extent.Get("threadIdx.y")) { + int t = o_t.value()->value; + tot_threads *= t; + factors.push_back(t); + thread_axis.push_back("threadIdx.y"); + } + if (Optional o_t = constraints.thread_extent.Get("threadIdx.x")) { + int t = o_t.value()->value; + tot_threads *= t; + factors.push_back(t); + thread_axis.push_back("threadIdx.x"); + } + // generate vectorized loop + factors.push_back(vector_len); + // generate outer loop + factors[0] = (loop_extent + tot_threads * vector_len - 1) / (tot_threads * vector_len); + // create new loop vars + int n = factors.size(); + std::vector new_loop_vars; + new_loop_vars.reserve(n); + arith::Analyzer analyzer; + for (int i = 0; i < n; i++) { + const PrimExpr& factor = factors[i]; + Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); + analyzer.Bind(var, Range::FromMinExtent(0, factor)); + new_loop_vars.push_back(var); + } + // substitute fused loop var with new loop vars + PrimExpr substitute_value = 0; + for (int i = 0; i < n; i++) { + substitute_value *= factors[i]; + substitute_value += new_loop_vars[i]; + } + // Construct the new loop nest + Stmt body = Substitute(loop->body, [&](const Var& v) -> Optional { + if (v.same_as(loop->loop_var)) { + return substitute_value; + } else { + return NullOpt; + } + }); + PrimExpr predicate = substitute_value < loop->extent; + if (!analyzer.CanProve(predicate)) { + body = IfThenElse(predicate, body); + } + body = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, std::move(body)); + for (int i = n - 2; i >= 1; i--) { + body = For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, std::move(body), + IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1])); + } + return For(new_loop_vars[0], 0, factors[0], ForKind::kSerial, std::move(body)); +} + +Stmt CoalescedAccess::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_fuse = FuseNestLoops(stmt); + Stmt after_split = SplitBindVectorize(std::move(after_fuse), constraints); + return after_split; +} + +/*! + * \brief Get the index mapping of a specific stmt. + * The stmt is like: + * for i0: + * ... + * for in: + * A[f(i0, ..., in])] = B[i0, ..., in], + * where f is the index mapping we want to get. + * \param constraints The constraints, including the write region that is required to calculate + * the index mapping + * \return The mapping in the form of j0, ..., jm, where j0, ... 
jm = f(i0, ..., in) + */ +Array GetMapping(const Stmt& stmt, const ConstraintSet& constraints) { + Stmt body = stmt; + while (const ForNode* loop = body.as()) { + body = loop->body; + } + const BufferStoreNode* buf_store = TVM_TYPE_AS(body, BufferStoreNode); + BufferRegion write_region = constraints.write_region; + const Array& write_index = buf_store->indices; + ICHECK(write_region->region.size() == write_index.size() && + write_region->buffer.same_as(buf_store->buffer)); + Array result; + arith::Analyzer analyzer; + for (int i = 0; i < static_cast(write_region->region.size()); i++) { + PrimExpr pattern = analyzer.Simplify(write_index[i] - write_region->region[i]->min); + if (!is_zero(pattern)) { + result.push_back(pattern); + } + } + return result; +} + +Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body = stmt; + Map var_range; + Array loop_vars; + // Step 1. Get index mapping + Array mapping_pattern = GetMapping(stmt, constraints); + while (const ForNode* loop = body.as()) { + var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + loop_vars.push_back(loop->loop_var); + body = loop->body; + } + // Step 2. Get Inverse mapping + arith::Analyzer analyzer; + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); + auto iter_map = + arith::DetectIterMap(mapping_pattern, var_range, Bool(true), arith::Bijective, &analyzer); + CHECK_EQ(iter_map->indices.size(), loop_vars.size()); + Map inverse_mapping = arith::InverseAffineIterMap(iter_map->indices, loop_vars); + // Step 3. Generate new body + BufferRegion read_region = constraints.read_region; + BufferRegion write_region = constraints.write_region; + Array write_index; + Array read_index; + Array new_loop_vars; + Map substitute_map; + // Step 3.1 construct target buffer indices + for (int i = 0, j = 0; i < static_cast(write_region->region.size()); i++) { + if (is_one(write_region->region[i]->extent)) { + write_index.push_back(write_region->region[i]->min); + } else { + Var var = runtime::Downcast(loop_vars[j]).copy_with_suffix("_inverse"); + new_loop_vars.push_back(var); + substitute_map.Set(runtime::Downcast(loop_vars[j++]), var); + write_index.push_back(write_region->region[i]->min + var); + } + } + // Step 3.2 construct source buffer indices + for (int i = 0, j = 0; i < static_cast(read_region->region.size()); i++) { + if (is_one(read_region->region[i]->extent)) { + read_index.push_back(read_region->region[i]->min); + } else { + read_index.push_back( + read_region->region[i]->min + + Substitute(inverse_mapping[Downcast(loop_vars[j++])], substitute_map)); + } + } + BufferLoad new_buf_load = BufferLoad(read_region->buffer, read_index); + BufferStore new_buf_store = BufferStore(write_region->buffer, new_buf_load, write_index); + Stmt ret = new_buf_store; + // Step 3.3 construct loop body + for (int i = static_cast(new_loop_vars.size()) - 1; i >= 0; i--) { + PrimExpr extent = write_region->region[i]->extent; + ret = For(new_loop_vars[i], 0, extent, ForKind::kSerial, std::move(ret)); + } + return ret; +} +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_intermediate_stage.cc b/src/tir/transforms/memhammer_intermediate_stage.cc new file mode 100644 index 000000000000..fb60b0bf2460 --- /dev/null +++ b/src/tir/transforms/memhammer_intermediate_stage.cc @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +Stmt CopyLoopChain(const std::vector loops, const Stmt& inner_body, int ith = -1, + Stmt* ith_loop = nullptr) { + Stmt ret = inner_body; + for (int i = static_cast(loops.size() - 1); i >= 0; i--) { + ObjectPtr new_loop = make_object(*loops[i]); + new_loop->body = ret; + ret = For(new_loop); + if (ith == i) { + *ith_loop = ret; + } + } + return ret; +} + +/*! + * \brief lift all the thread binding loops + * \param stmt the top loop + * \return a pair. The first is the transformed stmt. + * The second is the lowest thread binding loop. + */ +std::pair LiftThreadBindingLoops(Stmt stmt) { + std::vector normal_loops; + std::vector thread_binding_loops; + Stmt body = stmt; + while (const ForNode* loop = body.as()) { + if (loop->kind == ForKind::kThreadBinding) { + thread_binding_loops.push_back(loop); + } else { + normal_loops.push_back(loop); + } + body = loop->body; + } + body = CopyLoopChain(normal_loops, body); + For compute_location; + body = CopyLoopChain(thread_binding_loops, body, + static_cast(thread_binding_loops.size()) - 1, &compute_location); + + return std::make_pair(body, compute_location); +} + +/*! + * \brief Analyze the access pattern for buffer rank promotion. + * Rank promotion is a transformation that reshapes the buffer + * but doesn't change its underlying data layout. + * After the reshape, we expect that all dimensions of the access indices + * will be in the form of floormod(floordiv(x, a), b). + * Rank promotion removes strided access, thus enabling further buffer compacting + */ +class IndexPatternFinder : public ExprVisitor { + public: + IndexPatternFinder(const Map& var_range, Array* resulting_index) + : var_range_(var_range), resulting_index_(resulting_index) {} + struct Operator { + enum class OpKind { Mul, FloorDiv, FloorMod }; + OpKind kind; + int64_t operand; + }; + + /*! + * \brief Calculate the new buffer shape after rank promotion. + * For each dimension of original shape, it will be compacted. + * \param indices The access indices of the buffer + * \param var_range The iter range of the vars in the indices + * \param rewrite_indices The access indices after rank promotion + * \return The new buffer shape after rank promotion. 
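+ * For example (illustrative numbers): for a buffer A[1024] accessed as
+ * A[ty * 256 + tx * 8 + vec] with ty in [0, 2), tx in [0, 16) and vec in [0, 4),
+ * the collected access shape is {2, 16, 4}, so the promoted shape is [128] and the
+ * rewritten index is ty * 64 + tx * 4 + vec, i.e. the elements skipped by the
+ * stride-8 access are squeezed out.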
+ */ + static Array getRankPromotedShape(Array indices, + const Map& var_range, + Array* rewrite_indices) { + Map var_dom = arith::AsIntSet(var_range); + Array new_shape; + for (const PrimExpr& expr : indices) { + Array indices_dim; + IndexPatternFinder extractor(var_range, &indices_dim); + extractor(expr); + if (!extractor.success_) { + return {}; + } + Array access_shape = extractor.access_shape_; + PrimExpr product_shape = 1; + for (PrimExpr e : access_shape) { + product_shape *= e; + } + new_shape.push_back(product_shape); + PrimExpr flatten_index = 0; + for (int i = 0; i < static_cast(access_shape.size()); i++) { + flatten_index = flatten_index * access_shape[i] + indices_dim[i]; + } + rewrite_indices->push_back(flatten_index); + } + return new_shape; + } + + private: + void VisitExpr_(const VarNode* op) final { + if (!success_) { + return; + } + if (Optional range = var_range_.Get(GetRef(op))) { + PrimExpr index = GetRef(op); + int64_t max = range.value()->extent.as()->value; + int64_t extent = max; + for (int i = static_cast(operator_stack.size()) - 1; i >= 0; i--) { + Operator o = operator_stack[i]; + switch (o.kind) { + case Operator::OpKind::Mul: + max *= o.operand; + index = index * Integer(o.operand); + break; + case Operator::OpKind::FloorDiv: + if (max % o.operand != 0 && o.operand % max != 0) { + success_ = false; + return; + } + max = max / o.operand; + if (extent > max) { + extent = std::max(static_cast(1), max); + } + if (max % extent != 0) { + success_ = false; + return; + } + index = floordiv(index, Integer(o.operand)); + break; + case Operator::OpKind::FloorMod: + int64_t step = max / extent; + if (step % o.operand != 0 && o.operand % step != 0) { + success_ = false; + return; + } + if (step % o.operand == 0) { + extent = 1; + max = 0; + } else { + extent = std::max(static_cast(1), std::min(extent, o.operand / step)); + max = extent * step; + } + index = floormod(index, Integer(o.operand)); + } + } + if (extent > 1) { + ICHECK(max % extent == 0); + access_shape_.push_back(Integer(extent)); + resulting_index_->push_back(floordiv(index, max / extent)); + } + } + } + + void VisitExpr_(const FloorDivNode* op) final { + int64_t b = op->b.as()->value; + operator_stack.push_back(Operator{Operator::OpKind::FloorDiv, b}); + ExprVisitor::VisitExpr_(op); + operator_stack.pop_back(); + } + + void VisitExpr_(const FloorModNode* op) final { + int64_t b = op->b.as()->value; + operator_stack.push_back(Operator{Operator::OpKind::FloorMod, b}); + ExprVisitor::VisitExpr_(op); + operator_stack.pop_back(); + } + + void VisitExpr_(const MulNode* op) final { + int64_t b = op->b.as()->value; + operator_stack.push_back(Operator{Operator::OpKind::Mul, b}); + ExprVisitor::VisitExpr_(op); + operator_stack.pop_back(); + } + + Map var_range_; + Array access_shape_; + Array* resulting_index_; + std::vector operator_stack; + bool success_ = true; +}; + +class BufferLoadReplacer : public StmtExprMutator { + public: + BufferLoadReplacer(const Buffer& tgt_buffer, const BufferLoad& new_buffer_load) + : tgt_buffer_(tgt_buffer), new_buffer_load_(new_buffer_load) {} + + PrimExpr VisitExpr_(const BufferLoadNode* op) { + if (op->buffer.same_as(tgt_buffer_)) { + return new_buffer_load_; + } + return StmtExprMutator::VisitExpr_(op); + } + + private: + Buffer tgt_buffer_; + BufferLoad new_buffer_load_; +}; + +/*! 
+ * \brief Insert a cache stage to the compute location + * \param stmt the stmt + * \param is_write_cache whether to write a read cache or write cache + * \param storage_scope the storage scope of the new cache + * \param compute_location the compute location. + * \param outer_loops the outer loops of this stmt + * \param alloc_buffer the new cache block + * \return a pair. The first is the stmt after transformation. + * The second is the SeqStmt that contains 2 stages (one original and another inserted). + */ +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, + Optional compute_location, + const Array& outer_loops, Buffer* alloc_buffer) { + Stmt body = stmt; + std::vector loops; + std::vector loops_under_compute_location; + std::vector relaxed_thread_loops; + bool need_relax = !compute_location.defined(); + Map var_range; + PrimExpr vector_bytes = -1; + // Step 1. Perform rank promotion on the buffer access, turning a strided-changing dimension into + // several contiguous-changing dimensions + // Step 1.1 collect loop var range for rank promotion + while (const ForNode* loop = body.as()) { + if (need_relax) { + var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + loops_under_compute_location.push_back(loop); + } else { + loops.push_back(loop); + } + if (loop == compute_location.value_or(For()).get()) { + need_relax = true; + } + if (loop->kind == ForKind::kVectorized) { + vector_bytes = loop->extent; + } + body = loop->body; + } + Optional predicate; + if (const auto* op = body.as()) { + // the predicate is generated by coalescing + predicate = op->condition; + body = op->then_case; + } + for (const For& loop : outer_loops) { + if (loop->kind == ForKind::kThreadBinding) { + const String& thread_tag = loop->thread_binding.value()->thread_tag; + if (CanRelaxStorageUnderThread(runtime::StorageScope::Create(storage_scope), + runtime::ThreadScope::Create(thread_tag))) { + var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + relaxed_thread_loops.push_back(loop.get()); + } + } + } + + arith::Analyzer analyzer; + const BufferLoadNode* target_buffer_load = nullptr; + if (is_write_cache) { + tir::PreOrderVisit(stmt, [&](const ObjectRef& obj) { + if (const auto* buffer_load = obj.as()) { + if (buffer_load->buffer.scope() == "wmma.accumulator") { + if (target_buffer_load == nullptr) { + target_buffer_load = buffer_load; + } else { + CHECK(target_buffer_load->buffer.same_as(buffer_load->buffer)) + << "More than one target buffer found"; + ICHECK(target_buffer_load->indices.size() == buffer_load->indices.size()); + for (size_t i = 0; i < target_buffer_load->indices.size(); i++) { + CHECK( + analyzer.CanProveEqual(target_buffer_load->indices[i], buffer_load->indices[i])); + } + } + } + } + return true; + }); + CHECK(target_buffer_load); + } + + const BufferStoreNode* buf_store = TVM_TYPE_AS(body, BufferStoreNode); + Array cache_indices; + Array new_shape; + bool use_rank_promotion = false; + if (!is_write_cache && buf_store->value.as()) { + Array indices = + is_write_cache ? 
buf_store->indices : buf_store->value.as()->indices; + new_shape = IndexPatternFinder::getRankPromotedShape(indices, var_range, &cache_indices); + // write cache disabled for now + // rank promotion for write cache cannot guarantee the shape fits wmma.accumulator + if (!new_shape.empty()) { + use_rank_promotion = true; + } + } + Array new_loop_vars; + Map subst_map; + if (!use_rank_promotion) { + cache_indices.clear(); + for (const ForNode* loop : relaxed_thread_loops) { + new_shape.push_back(loop->extent); + } + for (const ForNode* loop : loops_under_compute_location) { + new_shape.push_back(loop->extent); + } + } + + for (int i = 0; i < static_cast(relaxed_thread_loops.size()); i++) { + const ForNode* loop = relaxed_thread_loops[i]; + Var new_loop_var = loop->loop_var.copy_with_suffix("_cache"); + new_loop_vars.push_back(new_loop_var); + subst_map.Set(loop->loop_var, new_loop_var); + if (!use_rank_promotion) { + cache_indices.push_back(loop->loop_var); + } + } + for (int i = 0; i < static_cast(loops_under_compute_location.size()); i++) { + const ForNode* loop = loops_under_compute_location[i]; + Var new_loop_var = loop->loop_var.copy_with_suffix("_cache"); + new_loop_vars.push_back(new_loop_var); + subst_map.Set(loop->loop_var, new_loop_var); + if (!use_rank_promotion) { + cache_indices.push_back(loop->loop_var); + } + } + Array subst_indices; + Array subst_cache_indices; + if (is_write_cache) { + for (PrimExpr e : buf_store->indices) { + subst_indices.push_back(Substitute(e, subst_map)); + } + } + for (PrimExpr e : cache_indices) { + subst_cache_indices.push_back(Substitute(e, subst_map)); + } + + Buffer new_buffer; + if (is_write_cache) { + // this is needed for global <- cast(load(wmma)) + // shared stage should have the same dtype as wmma + new_buffer = WithScope(target_buffer_load->buffer, storage_scope); + } else { + new_buffer = WithScope(buf_store->buffer, storage_scope); + } + BufferNode* buffer_ptr = new_buffer.CopyOnWrite(); + buffer_ptr->shape = new_shape; + *alloc_buffer = new_buffer; + + Stmt generate_body; + if (is_write_cache) { + // copy from wmma to new cache buffer + BufferLoad new_buffer_load{new_buffer, cache_indices}; + generate_body = + BufferLoadReplacer(target_buffer_load->buffer, new_buffer_load)(GetRef(buf_store)); + generate_body = Substitute(generate_body, subst_map); + } else { + generate_body = + BufferStore(new_buffer, Substitute(buf_store->value, subst_map), subst_cache_indices); + } + + if (predicate.defined()) { + // generated by coalescing + CHECK_EQ(loops_under_compute_location.size(), 2); + PrimExpr subst_value = 0; + PrimExpr subst_predicate = Substitute(predicate.value(), subst_map); + generate_body = IfThenElse(subst_predicate, generate_body); + } + + for (int i = static_cast(loops_under_compute_location.size()) - 1; i >= 0; i--) { + const ForNode* orig_loop = loops_under_compute_location[i]; + ObjectPtr new_loop = make_object(*orig_loop); + new_loop->loop_var = new_loop_vars[i + relaxed_thread_loops.size()]; + new_loop->body = generate_body; + generate_body = For(new_loop); + } + for (int i = static_cast(relaxed_thread_loops.size()) - 1; i >= 0; i--) { + const ForNode* orig_loop = relaxed_thread_loops[i]; + ObjectPtr new_loop = make_object(*orig_loop); + new_loop->loop_var = new_loop_vars[i]; + new_loop->body = generate_body; + new_loop->kind = ForKind::kSerial; + new_loop->thread_binding = NullOpt; + new_loop->annotations = {}; + generate_body = For(new_loop); + } + Stmt rewrite_body; + if (is_write_cache) { + BufferLoad 
new_buffer_load{new_buffer, cache_indices}; + rewrite_body = BufferStore(new_buffer, GetRef(target_buffer_load), cache_indices); + } else { + rewrite_body = + BufferStore(buf_store->buffer, BufferLoad(new_buffer, cache_indices), buf_store->indices); + } + if (predicate.defined()) { + rewrite_body = IfThenElse(predicate.value(), rewrite_body); + } + for (int i = static_cast(loops_under_compute_location.size()) - 1; i >= 0; i--) { + const ForNode* orig_loop = loops_under_compute_location[i]; + ObjectPtr new_loop = make_object(*orig_loop); + new_loop->body = rewrite_body; + rewrite_body = For(new_loop); + } + SeqStmt insert_location; + if (is_write_cache) { + generate_body = insert_location = SeqStmt({rewrite_body, generate_body}); + } else { + generate_body = insert_location = SeqStmt({generate_body, rewrite_body}); + } + generate_body = CopyLoopChain(loops, generate_body); + return std::make_pair(generate_body, insert_location); +} + +Stmt CreateLocalStage::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body; + For compute_location; + std::tie(body, compute_location) = LiftThreadBindingLoops(std::move(stmt)); + Buffer cache_buffer; + Stmt after_caching = InsertCacheStage(body, false, "local", compute_location, + constraints.outer_loops, &cache_buffer) + .first; + if (cache_buffer.defined()) { + output->alloc_buffer.push_back(cache_buffer); + } + return after_caching; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc new file mode 100644 index 000000000000..1446dca308a8 --- /dev/null +++ b/src/tir/transforms/memhammer_lower_auto_copy.cc @@ -0,0 +1,779 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../../runtime/thread_storage_scope.h" +#include "../schedule/utils.h" +#include "./ir_utils.h" +#include "./memhammer_rewrite_rule.h" +#include "tvm/tir/stmt.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +// rewrite rules +static InverseMapping inverse_mapping; +static CoalescedAccess coalesced_access; +static CreateLocalStage create_local_stage; +static SharedToWmma shared_to_wmma; +static WmmaToGlobal wmma_to_global; +static WmmaToShared wmma_to_shared; + +/*! + * \brief A class to perform auto padding. + * + * One simple way to perform auto padding is to fix each padding size for each dimension at the + * same time, calculate the precise access index and the bank conflict, + * and choose the one with minimal conflict. However, this algorithm has exponential complexity. 
+ * Suppose we have d dimensions and the padding size is 0-31, we need to calculate bank + * conflict for 32^{d-1} times. + * We propose a fast incremental algorithm that works for affine inputs, and it only calculate + * bank conflict for 32*{d-1} times. To be specific, we first decide the optimal padding size for + * dimension d-2, then for dimension d-3, ..., finally for dimension 0. It involves 2 steps. + * + * First, we analyze how a typical warp accesses the shared memory banks. + * A typical warp means setting all irrelevant loop vars to 0, and only keeps the threads in a warp. + * For each dimension, the access index is represented by + * x_1 * scale_1 + ... + x_n * scale_n (x_i is loop var) + * Note: The affine property guarantees that {x_i} must be independent, + * otherwise the algorithm is wrong. + * We will use this information to keep a list for each dimension called "iteration space" that + * records the resulting index as x_i takes each possible value. + * + * For example, the index is [outer*2+ty, tx*4+vec], where tx is threadIdx.x, and ty is threadIdx.y. + * tx is in [0, 16), and ty is in [0, 2). + * We will first get a warp access [ty, tx*4] because outer and vec are irrelevant loop vars. + * It's obvious that ty, tx*4 are both in the form of x_1 * scale_1 + ... + x_n * scale_n. + * In this case, we will keep lists {{0, 1}, {0, 4, ..., 60}} + * + * Next, we choose a padding size that has minimal conflict from the last dimension to first one. + * To calculate the conflict, we calculate the Cartesian product of the iteration space of all + * dimensions not higher than this. Each single point of product space represents access index + * of a particular thread, by which we can calculate the accessed memory bank. The conflict is + * the highest access frequency among the banks. + * + */ +class AutoPadder { + public: + /** + * \brief Do padding to the given buffers in shard memory + * \param buffers the given buffers + * \return the list of new padded buffers + */ + Array PadSharedMemory(const Array& buffers) { + Array result; + + for (const Buffer& buffer : buffers) { + runtime::StorageScope scope = runtime::StorageScope::Create(buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + auto iter_spaces = iter_spaces_[buffer.get()]; + if (iter_spaces.empty()) { + result.push_back(buffer); + continue; + } + // The access index represented by points in the cartesian product of lower dimension + // iteration spaces + std::vector> low_dim_iter_space(iter_spaces.size(), std::vector()); + + int n = buffer->shape.size(); + int data_bits = buffer->dtype.bits(); + // Step 1. initialize `low_dim_iter_space` with the iteration space of the last dim + for (int i = 0; i < static_cast(iter_spaces.size()); i++) { + auto last_dim_iter_space = iter_spaces[i][n - 1]; + low_dim_iter_space[i] = last_dim_iter_space; + } + PrimExpr stride = 1; + Array reverse_strides; + int pad_min = padding_min_.Get(buffer).value_or(Integer(1)).IntValue(); + // Step 2. 
For each dimension, select a padding that has minimal bank conflict + for (int k = n - 2; k >= 0; k--) { // dims + int max_pad_size = + std::min(static_cast(max_pad_factor_ * + (stride * buffer->shape[k + 1]).as()->value), + 32 * 32 / data_bits); + int min_conflict = INT32_MAX; + int min_conflict_pad = -1; + for (int pad = 0; pad <= max_pad_size; pad += pad_min) { // select padding + int padded_stride = ((stride * buffer->shape[k + 1]).as()->value + pad) % + (32 * 32 / data_bits); + int conflict = 0; + for (int i = 0; i < static_cast(iter_spaces.size()); i++) { // accesses + auto iter_space = iter_spaces[i][k]; + int bank[32]{0}; + for (int v1 : iter_space) { + for (int v2 : low_dim_iter_space[i]) { + int comb = (v1 * padded_stride + v2) * data_bits / 32 % 32; + bank[comb]++; + } + } + for (int j = 0; j < 32; j++) { + conflict = std::max(conflict, bank[j]); + } + } + if (conflict < min_conflict) { + min_conflict = conflict; + min_conflict_pad = pad; + } + } + // update low_dim_iter_space with + for (int i = 0; i < static_cast(iter_spaces.size()); i++) { // accesses + auto iter_space = iter_spaces[i][k]; + if (!iter_space.empty()) { + int padded_stride = + ((stride * buffer->shape[k + 1]).as()->value + min_conflict_pad) % + (32 * 32 / data_bits); + std::vector span; + for (int v1 : iter_space) { + for (int v2 : low_dim_iter_space[i]) { + span.push_back(((v1 * padded_stride + v2) * data_bits) % (32 * 32 / data_bits)); + } + } + low_dim_iter_space[i] = span; + } + } + stride = stride * buffer->shape[k + 1] + min_conflict_pad; + reverse_strides.push_back(stride); + } + // Step 3. create the new padded buffer + ObjectPtr b = make_object(*buffer.get()); + Array strides; + for (int i = static_cast(reverse_strides.size()) - 1; i >= 0; i--) { + strides.push_back(reverse_strides[i]); + } + strides.push_back(1); + b->strides = strides; + Buffer new_buffer(b); + result.push_back(new_buffer); + padded_buffer_map_.Set(buffer, new_buffer); + } else { + result.push_back(buffer); + } + } + return result; + } + + /** + * \brief Replace all occurrence of the old buffer with the new buffer in the stmt + * \param stmt the stmt to do replacement + * \return the stmt after replacement + */ + Stmt RewriteBufferAccess(const Stmt& stmt) { + class Rewriter : public StmtExprMutator { + public: + explicit Rewriter(const Map& buffer_map) : buffer_map_(buffer_map) {} + + private: + PrimExpr VisitExpr_(const BufferLoadNode* _op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); + BufferLoadNode* op = load.CopyOnWrite(); + if (buffer_map_.count(op->buffer)) { + op->buffer = buffer_map_[op->buffer]; + } + return std::move(load); + } + + Stmt VisitStmt_(const BufferStoreNode* _op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); + BufferStoreNode* op = store.CopyOnWrite(); + if (buffer_map_.count(op->buffer)) { + op->buffer = buffer_map_[op->buffer]; + } + return std::move(store); + } + + Stmt VisitStmt_(const BlockNode* op) final { + // To reduce the number of blocks in block sref reuse map, we check whether the block is + // really mutated (i.e., the old buffer appears in the block). If so, we return the block + // after mutation. Otherwise we just return the original block. + bool changed = false; + // Step 1. Mutate the read region. 
+ Array reads; + for (const BufferRegion& read : op->reads) { + if (buffer_map_.count(read->buffer)) { + changed = true; + reads.push_back(BufferRegion(buffer_map_[read->buffer], read->region)); + } else { + reads.push_back(read); + } + } + // Step 2. Mutate the write region. + Array writes; + for (const BufferRegion& write : op->writes) { + if (buffer_map_.count(write->buffer)) { + changed = true; + writes.push_back(BufferRegion(buffer_map_[write->buffer], write->region)); + } else { + writes.push_back(write); + } + } + // Step 4. Mutate `match_buffers`. If an old buffer appears as a source of + // MatchBufferRegion, the storage scope of the target buffer also needs to be set. + Array match_buffers; + for (const MatchBufferRegion& match_buffer : op->match_buffers) { + if (buffer_map_.count(match_buffer->source->buffer)) { + changed = true; + Buffer new_buffer = buffer_map_[match_buffer->source->buffer]; + match_buffers.push_back(MatchBufferRegion( + match_buffer->buffer, BufferRegion(new_buffer, match_buffer->source->region))); + } else { + match_buffers.push_back(match_buffer); + } + } + // Step 5. Recursively mutate the block. + Stmt res = StmtMutator::VisitStmt_(op); + if (res.get() != op) { + changed = true; + } + + if (changed) { + ObjectPtr block = CopyOnWrite(res.as()); + block->reads = std::move(reads); + block->writes = std::move(writes); + block->match_buffers = std::move(match_buffers); + return Stmt(block); + } else { + return GetRef(op); + } + } + const Map& buffer_map_; + }; + Rewriter rewriter(padded_buffer_map_); + return rewriter(stmt); + } + + /** + * \brief an equivalent of scale * loop_var with loop_var: {min=0, extent=extent} + */ + struct Pattern { + int extent; + int scale; + }; + + /** + * \brief Collect pattern from indices + */ + class PatternCollector : public StmtExprVisitor { + void VisitExpr_(const VarNode* op) final { + if (!success_) { + return; + } + int extent = var_range_[GetRef(op)]->extent.as()->value; + if (extent > 1) { + stack_.push({{extent, 1}}); + } else { + stack_.push({}); + } + } + + void VisitExpr_(const AddNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector merged_patterns; + std::vector r = stack_.top(); + stack_.pop(); + std::vector l = stack_.top(); + stack_.pop(); + for (const Pattern& pattern : l) { + merged_patterns.push_back(pattern); + } + for (const Pattern& pattern : r) { + merged_patterns.push_back(pattern); + } + if (merged_patterns.empty()) { + stack_.push({}); + return; + } + std::vector ret; + ret.push_back(merged_patterns[0]); + for (int i = 0; i < static_cast(merged_patterns.size()); i++) { + Pattern prev_pattern = ret.back(); + if (merged_patterns[i].extent * merged_patterns[i].scale == prev_pattern.scale) { + ret.pop_back(); + ret.push_back( + {prev_pattern.extent * merged_patterns[i].extent, merged_patterns[i].scale}); + } + } + stack_.push(ret); + } + + void VisitExpr_(const FloorDivNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int lower_factor = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + if (pattern.scale >= lower_factor) { + if (pattern.scale % lower_factor == 0) { + ret.push_back({pattern.extent, pattern.scale / lower_factor}); + } else { + success_ = false; + } + } else if (pattern.scale * pattern.extent > lower_factor) { + if ((pattern.scale * pattern.extent) % lower_factor == 0) { + ret.push_back({pattern.extent * pattern.scale / lower_factor, 
1}); + } else { + success_ = false; + } + } + } + stack_.push(ret); + } + + void VisitExpr_(const FloorModNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int extent = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + if (pattern.scale < extent) { + if (extent % pattern.scale == 0) { + if (extent / pattern.scale < pattern.extent) { + ret.push_back({extent / pattern.scale, pattern.scale}); + } else { + ret.push_back({pattern.extent, pattern.scale}); + } + } else { + success_ = false; + } + } + } + stack_.push(ret); + } + + void VisitExpr_(const MulNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int scale = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + ret.push_back({pattern.extent, pattern.scale * scale}); + } + stack_.push(ret); + } + + public: + explicit PatternCollector(const Map& var_range) : var_range_(var_range) {} + + /*! + * \brief Collect the iteration space for given indices. The iteration space is the possible + * values that an index can take (do not remove duplicate). + * For example, the input is [ty, tx*4], where tx is in [0, 16), and ty is in [0, 2). + * The output would be {{0, 1}, {0, 4, ..., 60}} + * \param indices The indices to analyze + * \param var_range The range of loop variables + * \param data_bits The size of dtype in bits + * \return The iteration space. The first array represents dimensions, and the second array + * represents the iteration space of one dimension + */ + static std::vector> CollectIterationSpace(const Array& indices, + const Map& var_range, + int data_bits) { + PatternCollector collector(var_range); + std::vector> ret; + for (int i = 0; i < static_cast(indices.size()); i++) { + collector(indices[i]); + if (collector.success_ && collector.stack_.size() == 1) { + auto patterns = collector.stack_.top(); + int extent_prod = 1; + for (const Pattern& p : patterns) { + extent_prod *= p.extent; + } + std::vector iter_space; + for (int thread_id = 0; thread_id < extent_prod; thread_id++) { + int index = 0; + int n = thread_id; + for (int j = static_cast(patterns.size()) - 1; j >= 0; j--) { + int val = n % patterns[j].extent; + index += val * patterns[j].scale; + n /= patterns[j].extent; + } + iter_space.push_back(index); + } + + ret.push_back(iter_space); + collector.stack_.pop(); + } else { + ret.push_back({}); + } + } + return ret; + } + + std::stack> stack_; + const Map& var_range_; + bool success_ = true; + }; + + /*! 
A utility class for calling CollectIterationSpace to each buffer access*/ + class IterSpaceAnalyzer : public StmtExprVisitor { + public: + IterSpaceAnalyzer(const Map& substitute_map, AutoPadder* self, int data_bits, + const Map warp_thread_extent) + : substitute_map_(substitute_map), + self(self), + data_bits_(data_bits), + warp_thread_extent_(warp_thread_extent) {} + + private: + bool CheckVarContiguous(PrimExpr e, Var var, const Map& subst_map) { + PrimExpr e1 = Substitute(e, [var](const Var& v) -> Optional { + if (v.same_as(var)) { + return Integer(0); + } else { + return v; + } + }); + PrimExpr e2 = Substitute(e, [var](const Var& v) -> Optional { + if (v.same_as(var)) { + return Integer(1); + } else { + return v; + } + }); + arith::Analyzer analyzer; + return !analyzer.CanProve(Substitute(e2 - e1, subst_map) != 1); + } + + void VisitStmt_(const ForNode* op) final { + if (op->kind != ForKind::kThreadBinding) { + substitute_map_.Set(op->loop_var, op->min); + } else { + Integer extent = + warp_thread_extent_.Get(op->thread_binding.value()->thread_tag).value_or(1); + var_range_.Set(op->loop_var, Range::FromMinExtent(op->min, extent)); + } + if (op->kind == ForKind::kVectorized) { + vector_var = op->loop_var; + vector_length_ = op->extent.as()->value; + } + StmtExprVisitor::VisitStmt_(op); + if (op->kind == ForKind::kVectorized) { + vector_length_ = -1; + } + if (op->kind != ForKind::kThreadBinding) { + substitute_map_.erase(op->loop_var); + } + } + /*! + * \brief Take a typical warp and collect the iteration space for buffer store + * For example, the access is A[outer*2+ty, tx*4+vec] = xxx, where tx is threadIdx.x, and ty is + * threadIdx.y. tx is in [0, 16), and ty is in [0, 2). + * The iteration space would be {{0, 1}, {0, 4, ..., 60}}. + * \param op the buffer store + */ + void VisitStmt_(const BufferStoreNode* op) final { + runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : op->indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = + PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[op->buffer.get()].push_back(iter_space); + } + if (vector_length_ != -1 && + CheckVarContiguous(op->indices.back(), vector_var, substitute_map_)) { + Integer m = self->padding_min_.Get(op->buffer).value_or(1); + self->padding_min_.Set(op->buffer, Downcast(max(vector_length_, m))); + } + } + StmtExprVisitor::VisitStmt_(op); + } + /*! + * \brief Take a typical warp and collect the iteration space for buffer load + * For example, the access is xxx = A[outer*2+ty, tx*4+vec], where tx is threadIdx.x, and ty is + * threadIdx.y. tx is in [0, 16), and ty is in [0, 2). + * The iteration space would be {{0, 1}, {0, 4, ..., 60}}. 
+ * \param op the buffer load + */ + void VisitExpr_(const BufferLoadNode* op) final { + runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : op->indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = + PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[op->buffer.get()].push_back(iter_space); + } + if (vector_length_ != -1 && + CheckVarContiguous(substitued_indices.back(), vector_var, substitute_map_)) { + Integer m = self->padding_min_.Get(op->buffer).value_or(1); + self->padding_min_.Set(op->buffer, Downcast(max(vector_length_, m))); + } + } + StmtExprVisitor::VisitExpr_(op); + } + + /*! + * \brief Take a typical warp and collect the iteration space for load_matrix_sync and + * store_matrix_sync + * For example, the access region is A[y*16+16, x*16+16], where y and x are not bound to + * threadIdx. The iteration space would be {{0, 1, ..., 15}, {0, 1, ..., 15}}. + * \param op the call node + */ + void VisitStmt_(const BlockNode* op) final { + if (const auto* eval = op->body.as()) { + if (const auto* call = eval->value.as()) { + if (call->op == builtin::tvm_load_matrix_sync() || + call->op == builtin::tvm_store_matrix_sync()) { + for (const MatchBufferRegion& r : op->match_buffers) { + Buffer src_buffer = r->source->buffer; + runtime::StorageScope scope = runtime::StorageScope::Create(src_buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Region region = r->source->region; + Array indices; + for (int i = 0; i < static_cast(region.size()); i++) { + Var var("region" + std::to_string(i)); + indices.push_back(region[i]->min + var); + var_range_.Set(var, Range::FromMinExtent(0, region[i]->extent)); + } + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = PatternCollector::CollectIterationSpace( + substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[src_buffer.get()].push_back(iter_space); + } + } + } + } + } + } + } + + Map substitute_map_; + AutoPadder* self; + int data_bits_; + Map warp_thread_extent_; + Map var_range_; + int vector_length_ = -1; + Var vector_var; + }; + + /*! 
+ * \brief Analyze the shared memory access + * \param stmt The data copy + * \param outer_loops The outer loops of the stmt + * \param data_bits The length of dtype in bits + * \param thread_extent The extents of all thread binding loops + */ + void AnalyzeSharedMemoryAccess(const Stmt& stmt, const Array& outer_loops, int data_bits, + const Map& thread_extent) { + Map warp_thread_extent; + Integer prod = 1; + Array thread_tags{"threadIdx.x", "threadIdx.y", "threadIdx.z"}; + arith::Analyzer analyzer; + for (int i = 0; i < 3; i++) { + Integer extent = thread_extent.Get(thread_tags[i]).value_or(1); + if (analyzer.CanProve(prod * extent >= 32)) { + warp_thread_extent.Set(thread_tags[i], Downcast(floordiv(32, prod))); + prod *= floordiv(32, prod); + break; + } else { + warp_thread_extent.Set(thread_tags[i], Downcast(extent)); + prod *= extent; + } + } + Map substitute_map; + for (const For& loop : outer_loops) { + substitute_map.Set(loop->loop_var, loop->min); + } + IterSpaceAnalyzer iter_space_analyzer(substitute_map, this, data_bits, warp_thread_extent); + iter_space_analyzer(stmt); + } + + private: + /*! \brief A map from the old buffers to the new padded buffers */ + Map padded_buffer_map_; + /*! \brief A map from each buffer to the iteration spaces of the accesses*/ + std::unordered_map>>> iter_spaces_; + /*! \brief A map from each buffer to their minimal padding size */ + Map padding_min_; + /*! \brief max padding size in relative to the original shape*/ + const double max_pad_factor_ = 0.25; + + friend class AutoCopyMutator; +}; + +class AutoCopyMutator : public StmtExprMutator { + public: + explicit AutoCopyMutator(Map thread_extent) : thread_extent_(thread_extent) {} + /** + * \brief Replace old buffers with padded buffers in the stmt + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ + Stmt RewritePaddingBody(const Stmt& stmt) { return padder.RewriteBufferAccess(stmt); } + + private: + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtMutator::VisitStmt_(op)); + // only rewrite the block annotated with "auto_copy" + if (GetAnn(op, tir::attr::auto_copy).value_or(0)->value == 0) { + BlockNode* n = block.CopyOnWrite(); + n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers)); + return std::move(block); + } + ICHECK_EQ(block->writes.size(), 1); + ICHECK_GE(block->reads.size(), 1); + + BufferRegion target_read = block->reads[0]; + if (block->reads.size() > 1) { + bool found = false; + for (size_t i = 0; i < block->reads.size(); i++) { + if (block->reads[i]->buffer.scope() == "wmma.accumulator") { + found = true; + target_read = block->reads[i]; + } + } + ICHECK(found) << "Multiple buffer read"; + } + + int data_bits = target_read->buffer->dtype.bits(); + ConstraintSet constraints(this->thread_extent_, // + this->outer_loops_, // + target_read, // + block->writes[0], // + data_bits, // + block->annotations); + BlockNode* n = block.CopyOnWrite(); + OutputSet outputs; + for (RewriteRule* rule : rules) { + n->body = rule->Apply(std::move(n->body), constraints, &outputs); + } + for (const Buffer& buffer : outputs.alloc_buffer) { + n->alloc_buffers.push_back(buffer); + } + for (const auto& p : outputs.padding_min) { + Integer m = padder.padding_min_.Get(p.first).value_or(1); + padder.padding_min_.Set(p.first, Downcast(max(p.second, m))); + } + padder.AnalyzeSharedMemoryAccess(block->body, outer_loops_, data_bits, thread_extent_); + n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers)); + return std::move(block); + 
} + + Stmt VisitStmt_(const ForNode* op) final { + outer_loops_.push_back(GetRef(op)); + Stmt stmt = StmtMutator::VisitStmt_(op); + outer_loops_.pop_back(); + return stmt; + } + + /*! \brief Thread extents collected. */ + Map thread_extent_; + /*! \brief The outer loops during recursive visit */ + Array outer_loops_; + /*! \brief Calculating optimal padding size */ + AutoPadder padder; + + /*! \brief All rewrite rules. */ + const std::array rules = {&inverse_mapping, // + &coalesced_access, // + &create_local_stage, // + &shared_to_wmma, // + &wmma_to_global, // + &wmma_to_shared}; +}; + +/*! + * \brief Collect the extent for all thread binding loops. + */ +class ThreadExtentCollector : public StmtVisitor { + public: + static Map CollectThreadExtent(const Stmt& stmt) { + ThreadExtentCollector collector; + collector(stmt); + return collector.thread_extent_; + } + + private: + void VisitStmt_(const BlockNode* op) final { + if (Optional warp_execution = GetAnn(op, "warp_execution")) { + if (warp_execution.value()->value != 0) { + thread_extent_.Set("threadIdx.x", Integer(32)); + } + } + StmtVisitor::VisitStmt_(op); + } + void VisitStmt_(const ForNode* op) final { + if (op->thread_binding.defined() && op->thread_binding.value()->iter_type == kThreadIndex) { + thread_extent_.Set(op->thread_binding.value()->thread_tag, Downcast(op->extent)); + } + StmtVisitor::VisitStmt_(op); + } + + /*! \brief the map from thread tag to its extent */ + Map thread_extent_; +}; + +namespace transform { + +Pass LowerAutoCopy() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + AutoCopyMutator mutator(ThreadExtentCollector::CollectThreadExtent(n->body)); + n->body = mutator(std::move(n->body)); + n->body = mutator.RewritePaddingBody(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerAutoCopy", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerAutoCopy").set_body_typed(LowerAutoCopy); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_rewrite_rule.h b/src/tir/transforms/memhammer_rewrite_rule.h new file mode 100644 index 000000000000..a43cdabb21e5 --- /dev/null +++ b/src/tir/transforms/memhammer_rewrite_rule.h @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_TRANSFORMS_MEMHAMMER_REWRITE_RULE_H_ +#define TVM_TIR_TRANSFORMS_MEMHAMMER_REWRITE_RULE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../schedule/utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The set containing all possible constraints of a data copy */ +struct ConstraintSet { + /*! \brief The extents of the thread binding loops */ + Map thread_extent; + /*! 
\brief The outer loops surrounding the data copy */ + Array outer_loops; + /*! \brief The read region of the data copy */ + BufferRegion read_region; + /*! \brief The write region of the data copy */ + BufferRegion write_region; + /*! \brief The dtype size in bits */ + int data_bits; + /*! \brief Whether to insert a local stage in the data copy */ + int add_local_stage = 0; + /*! \brief The vectorization length in bytes */ + int vector_bytes = 1; + + explicit ConstraintSet(Map thread_extent, // + Array outer_loops, // + BufferRegion read_region, // + BufferRegion write_region, // + int data_bits, // + const Map& ann) + : thread_extent(thread_extent), + outer_loops(outer_loops), + read_region(read_region), + write_region(write_region), + data_bits(data_bits) { + if (Optional add_local_stage = ann.Get("local_stage")) { + this->add_local_stage = Downcast(add_local_stage.value())->value; + } + if (Optional vector_bytes = ann.Get("vector_bytes")) { + this->vector_bytes = Downcast(vector_bytes.value())->value; + } + } +}; + +/*! \brief The set containing all possible outputs of a rewrite rule */ +struct OutputSet { + /*! \brief New buffers allocated after rewrite */ + Array alloc_buffer; + /*! \brief The minimal padding size of a buffer in base 2 logarithm */ + Map padding_min; +}; + +/*! + * \brief Rules to rewrite a data copy. + */ +class RewriteRule { + protected: + /* RewriteRule() = default; */ + /*! + * \brief Rewrite the stmt under certain constraints + * \param stmt The stmt + * \param constraints The constraints of the rewrite + * \param output Some additional information that the rewrite rule produces. (including the new + * buffer to be allocated, etc.) + * \return the stmt after rewrite + */ + virtual Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const = 0; + /*! + * \brief Whether the rewrite rule can be applied to the stmt under certain constraints + * \param stmt The stmt + * \param constraints The constraints of the rewrite + * \return A boolean flag indicating whether the rule can be applied + */ + virtual bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const { return true; } + + public: + inline Stmt Apply(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { + if (CanApply(stmt, constraints)) { + return Rewrite(stmt, constraints, output); + } else { + return stmt; + } + } +}; + +inline bool IsCopyBetweenScope(const Buffer& src_buffer, const Buffer& tgt_buffer, + runtime::StorageRank src_rank, runtime::StorageRank tgt_rank) { + runtime::StorageScope src_scope = runtime::StorageScope::Create(src_buffer.scope()); + runtime::StorageScope tgt_scope = runtime::StorageScope::Create(tgt_buffer.scope()); + return src_scope.rank == src_rank && tgt_scope.rank == tgt_rank; +} + +inline bool IsScope(const Buffer& src_buffer, runtime::StorageRank src_rank) { + runtime::StorageScope src_scope = runtime::StorageScope::Create(src_buffer.scope()); + return src_scope.rank == src_rank; +} + +/*! + * \brief Coalesce and vectorize memory access. 
+ */ +class CoalescedAccess : public RewriteRule { + public: + CoalescedAccess() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kGlobal, + runtime::StorageRank::kShared) || + IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Transform from A[f(i,j)] = B[i,j] to A[i,j] = B[f^{-1}(i,j)] + */ +class InverseMapping : public RewriteRule { + public: + InverseMapping() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Create a local stage when loading from global memory to shared memory. + */ +class CreateLocalStage : public RewriteRule { + public: + CreateLocalStage() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kGlobal, + runtime::StorageRank::kShared) && + is_one(constraints.add_local_stage); + } +}; + +/*! + * \brief Add a cache stage in shared memory. Perform tensor core rewrite for wmma->shared, and + * perform coalescing and vectorizing for shared->global. + */ +class WmmaToGlobal : public RewriteRule { + public: + WmmaToGlobal() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kWMMAAccumulator, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Rewrite shared->wmma data copy with load_matrix_sync + */ +class SharedToWmma : public RewriteRule { + public: + SharedToWmma() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kWMMAMatrixA) || + IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kWMMAMatrixB); + } +}; + +/*! 
+ * \brief Rewrite wmma->shared data copy with store_matrix_sync + */ +class WmmaToShared : public RewriteRule { + public: + WmmaToShared() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kWMMAAccumulator, + runtime::StorageRank::kShared); + } +}; + +/*! + * \brief Insert a cache stage to the compute location + * \param stmt the stmt + * \param is_write_cache whether to write a read cache or write cache + * \param storage_scope the storage scope of the new cache + * \param compute_location the compute location. + * \param outer_loops the outer loops of this stmt + * \param alloc_buffer the new cache block + * \return a pair. The first is the stmt after transformation. + * The second is the SeqStmt that contains 2 stages (one original and another inserted). + */ +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, + Optional compute_location, + const Array& outer_loops, Buffer* alloc_buffer); + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_TRANSFORMS_MEMHAMMER_REWRITE_RULE_H_ diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc new file mode 100644 index 000000000000..86202de3970c --- /dev/null +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Tile the 2 innermost loops to extent=16. This helps further tensor core rewrite. + * \param stmt The stmt + * \return A pair. The first is the stmt after transformation. + * The second is the compute location where we may add write cache. 
+ */ +std::pair> TileWmmaBlock(Stmt stmt) { + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + PrimExpr extent_last1 = loops[n - 1]->extent; + PrimExpr extent_last2 = loops[n - 2]->extent; + { + arith::Analyzer analyzer; + if (!analyzer.CanProveEqual(floormod(extent_last1, 16), 0) || + !analyzer.CanProveEqual(floormod(extent_last2, 16), 0)) { + return std::make_pair(stmt, NullOpt); + } + } + Var new_loop_vars[4] = { + /*0:*/ loops[n - 2]->loop_var.copy_with_suffix("_0"), + /*1:*/ loops[n - 1]->loop_var.copy_with_suffix("_0"), + /*2:*/ loops[n - 2]->loop_var.copy_with_suffix("_1"), + /*3:*/ loops[n - 1]->loop_var.copy_with_suffix("_1"), + }; + body = Substitute(std::move(body), + Map{ + {loops[n - 2]->loop_var, new_loop_vars[0] * 16 + new_loop_vars[2]}, + {loops[n - 1]->loop_var, new_loop_vars[1] * 16 + new_loop_vars[3]}, + }); + { + PrimExpr factor[4] = { + /*0:*/ floordiv(extent_last2, 16), // + /*1:*/ floordiv(extent_last1, 16), // + /*3:*/ 16, // + /*4:*/ 16, // + }; + body = For(new_loop_vars[3], 0, factor[3], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[2], 0, factor[2], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[1], 0, factor[1], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[0], 0, factor[0], ForKind::kSerial, std::move(body)); + } + For compute_location = Downcast(body); + for (int i = n - 3; i >= 0; i--) { + body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, std::move(body), + loops[i]->thread_binding, loops[i]->annotations); + } + return {body, compute_location}; +} + +Array RelaxIndices(const Array& indices, const Array& shape, + const Map& var_dom) { + Array int_set; + int_set.reserve(indices.size()); + for (auto& indice : indices) { + int_set.push_back(arith::EvalSet(indice, var_dom)); + } + int ndim = int_set.size(); + Array region; + region.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + region.push_back(int_set[i].CoverRange(Range::FromMinExtent(0, shape[i]))); + } + return region; +} + +/*! + * \brief Rewrite the data copy that stores to wmma fragment with wmma::load_matrix_sync + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ +Stmt RewriteWmmaLoad(Stmt stmt) { + using arith::IntSet; + const DataType dtype = DataType::Float(16); + const DataType int32 = DataType::Int(32); + + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + + Map var_dom{ + {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, + {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, + }; + // TODO(tian): the assumption that the RHS of BufferStore is BufferLoad may not be accurate + const BufferStoreNode* buf_store = TVM_TYPE_AS(body, BufferStoreNode); + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_store->value, BufferLoadNode); + + Buffer src_buffer = buf_load->buffer; + Buffer tgt_buffer = buf_store->buffer; + std::string layout = tgt_buffer.scope() == "wmma.matrix_a" ? 
"row_major" : "col_major"; + Buffer new_src_buffer( + /*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{Var("s1", int32), Var("s0", int32)}, + /*elem_offset=*/Var("src_elem_offset", int32), + /*name=*/"src", + /*data_alignment=*/64, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Buffer new_tgt_buffer( + /*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{}, + /*elem_offset=*/Var("tgt_elem_offset", int32), + /*name=*/"tgt", + /*data_alignment=*/64, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + Stmt wmma_body = BlockRealize( + /*iter_values=*/{}, + /*predicate=*/Bool(true), + Block( + /*iter_vars=*/{}, + /*reads=*/{BufferRegion(src_buffer, read_region)}, + /*writes=*/{BufferRegion(tgt_buffer, write_region)}, + /*name_hint=*/"wmma_load", + /*body=*/ + Evaluate(Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_load_matrix_sync(), + { + /*0:*/ new_tgt_buffer->data, + /*1:*/ 16, + /*2:*/ 16, + /*3:*/ 16, + /*4:*/ floordiv(new_tgt_buffer->elem_offset, 256) + + floordiv(floormod(new_tgt_buffer->elem_offset, 256), 16), + /*5:*/ + Call( + /*dtype=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_access_ptr(), + /*args=*/ + { + /*0:*/ TypeAnnotation(new_src_buffer->dtype), + /*1:*/ new_src_buffer->data, + /*2:*/ new_src_buffer->elem_offset, + /*3:*/ new_src_buffer->strides[new_src_buffer->strides.size() - 2] * 16, + /*4:*/ 1, + }), + /*6:*/ new_src_buffer->strides[new_src_buffer->strides.size() - 2], + /*7:*/ StringImm(layout), + })), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/ + { + /*0:*/ MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), + /*1:*/ MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), + }, + /*annotations=*/{})); + for (int i = n - 3; i >= 0; i--) { + wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, + std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + } + return wmma_body; +} + +/*! 
+ * \brief Rewrite the data copy that loads from wmma fragment with wmma::store_matrix_sync + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ +Stmt RewriteWmmaStore(Stmt stmt) { + using arith::IntSet; + const DataType int32 = DataType::Int(32); + + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + + Map var_dom{ + {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, + {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, + }; + // TODO(tian): the assumption that the RHS of BufferStore is BufferLoad may not be accurate + const BufferStoreNode* buf_store = TVM_TYPE_AS(body, BufferStoreNode); + const BufferLoadNode* buf_load = nullptr; + PostOrderVisit(buf_store->value, [&](const ObjectRef& obj) { + const BufferLoadNode* load = obj.as(); + if (load && load->buffer.scope() == "wmma.accumulator") { + ICHECK(buf_load == nullptr || buf_load->buffer.same_as(load->buffer)) + << "More than one source buffer of wmma accumulator found"; + buf_load = load; + } + return true; + }); + Buffer src_buffer = buf_load->buffer; + Buffer tgt_buffer = buf_store->buffer; + + const DataType dtype = src_buffer->dtype; + + Buffer new_src_buffer(/*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{}, + /*elem_offset=*/Var("src_elem_offset", int32), + /*name=*/"src", + /*data_alignment=*/64, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Buffer new_tgt_buffer(/*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{Var("s1", int32), Var("s0", int32)}, + /*elem_offset=*/Var("tgt_elem_offset", int32), + /*name=*/"tgt", + /*data_alignment=*/64, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + + Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + Stmt wmma_body = BlockRealize( + /*iter_values=*/{}, // + /*predicate=*/Bool(true), + Block(/*iter_vars=*/{}, + /*reads=*/{BufferRegion(src_buffer, read_region)}, + /*writes=*/{BufferRegion(tgt_buffer, write_region)}, + /*name_hint=*/"wmma_store", + Evaluate(Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_store_matrix_sync(), + {/*0:*/ new_src_buffer->data, + /*1:*/ 16, + /*2:*/ 16, + /*3:*/ 16, + /*4:*/ floordiv(new_src_buffer->elem_offset, 256) + + floordiv(floormod(new_src_buffer->elem_offset, 256), 16), + /*5:*/ + Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_access_ptr(), + { + /*0:*/ TypeAnnotation(new_tgt_buffer->dtype), + /*1:*/ new_tgt_buffer->data, + /*2:*/ new_tgt_buffer->elem_offset, + /*3:*/ new_tgt_buffer->strides[0] * 16, + /*4:*/ 2, + }), + /*6:*/ new_tgt_buffer->strides[0], + /*7:*/ StringImm("row_major")})), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/ + { + MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), + MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), + }, + /*annotations=*/{})); + for (int i = n - 3; i >= 0; i--) { + wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, + std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + } + return wmma_body; +} + +Stmt SharedToWmma::Rewrite(const 
Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_tiling = TileWmmaBlock(stmt).first; + output->padding_min.Set(constraints.read_region->buffer, 8); + return RewriteWmmaLoad(after_tiling); +} + +Stmt WmmaToShared::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_tiling = TileWmmaBlock(stmt).first; + output->padding_min.Set(constraints.write_region->buffer, 8); + return RewriteWmmaStore(after_tiling); +} + +class WmmaToGlobalRewriter : public StmtExprMutator { + public: + WmmaToGlobalRewriter(const SeqStmtNode* tgt_stmt, const ConstraintSet& constraints) + : tgt_stmt_(tgt_stmt), constraints_(constraints) {} + + private: + Stmt VisitStmt_(const SeqStmtNode* op) final { + if (op == tgt_stmt_) { + ICHECK_EQ(op->seq.size(), 2); + Stmt wmma_to_shared = RewriteWmmaStore(op->seq[0]); + Stmt shared_to_global = CoalescedAccess().Rewrite(op->seq[1], constraints_, nullptr); + return SeqStmt({wmma_to_shared, shared_to_global}); + } else { + return StmtMutator::VisitStmt_(op); + } + } + + const SeqStmtNode* tgt_stmt_; + const ConstraintSet& constraints_; +}; + +Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body{nullptr}; + Optional compute_location{nullptr}; + std::tie(body, compute_location) = TileWmmaBlock(stmt); + SeqStmt seq{nullptr}; + Buffer cache_buffer; + // Step 1. add a shared memory cache + std::tie(body, seq) = InsertCacheStage(std::move(body), true, "shared.dyn", compute_location, + constraints.outer_loops, &cache_buffer); + output->alloc_buffer.push_back(cache_buffer); + output->padding_min.Set(cache_buffer, 8); + // Step 2. do coalesced rewrite and tensor core rewrite respectively for 2 parts + WmmaToGlobalRewriter rewriter(seq.get(), constraints); + return rewriter(body); +} + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py b/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py new file mode 100644 index 000000000000..27f48edae9f2 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py @@ -0,0 +1,1062 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +from tvm import te +from tvm.script import tir as T +import sys +import pytest + + +@tvm.script.ir_module +class Transpose: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([16, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 16): + A_shared_dyn[ax1, ax0] = A[ax0, ax1] + with T.block("B"): + T.block_attr({"auto_copy": 1}) + for ax1, ax0 in T.grid(16, 128): + B[ax1, ax0] = A_shared_dyn[ax1, ax0] + + +@tvm.script.ir_module +class GlobalToShared: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer( + [128, 128], dtype="float32", scope="shared.dyn" + ) + with T.block("A_shared"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class SharedToGlobal: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer( + [128, 128], dtype="float32", scope="shared.dyn" + ) + with T.block("A_shared"): + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax1, ax0] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax1, ax0 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax1, ax0] + + +@tvm.script.ir_module +class GlobalToSharedWithLocalStage: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer( + [128, 128], dtype="float32", scope="shared.dyn" + ) + with T.block("A_shared"): + T.block_attr( + {"auto_copy": 1, "vector_bytes": 16, "local_stage": True} + ) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class SharedToWmma: + @T.prim_func + def main() -> None: + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, 
thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer( + [128, 128], dtype="float16", scope="shared.dyn" + ) + A_wmma = T.alloc_buffer( + [128, 128], dtype="float16", scope="wmma.matrix_a" + ) + with T.block("A_wmma"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 128): + A_wmma[ax0, ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class WmmaToShared: + @T.prim_func + def main() -> None: + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer( + [128, 128], dtype="float32", scope="wmma.accumulator" + ) + C_shared = T.alloc_buffer( + [128, 128], dtype="float32", scope="shared.dyn" + ) + with T.block("C_shared"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 128): + C_shared[ax0, ax1] = C_accum[ax0, ax1] + + +@tvm.script.ir_module +class WmmaToGlobal: + @T.prim_func + def main(c: T.handle) -> None: + C = T.match_buffer(c, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer( + [128, 128], dtype="float32", scope="wmma.accumulator" + ) + with T.block("C_global"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax0, ax1 in T.grid(128, 128): + C[bx * 128 + ax0, by * 128 + ax1] = C_accum[ax0, ax1] + + +@tvm.script.ir_module +class WmmaToGlobalWithFusion: + @T.prim_func + def main(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [1024]) + C = T.match_buffer(c, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer( + [128, 128], dtype="float32", scope="wmma.accumulator" + ) + with T.block("C_global"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax0, ax1 in T.grid(128, 128): + C[bx * 128 + ax0, by * 128 + ax1] = ( + C_accum[ax0, ax1] + A[bx * 128 + ax0] + ) + + +@tvm.script.ir_module +class TransformedGlobalToShared: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer( + [128, 128], dtype="float32", strides=[128, 1], scope="shared.dyn" + ) + with T.block("A_shared"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for outer in T.serial(16): + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for vec in T.vectorized(4): + A_shared_dyn[ + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) + // 128 + % 128, + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) + % 128, + ] = A[ + bx * 128 + + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) + // 128 + % 128, + by * 128 + + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) + % 128, + ] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = 
A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class TransformedSharedToGlobal: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer( + [128, 128], dtype="float32", strides=[129, 1], scope="shared.dyn" + ) + with T.block("A_shared"): + T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.writes(A_shared_dyn[0:128, 0:128]) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax1, ax0] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for outer in T.serial(16): + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for vec in T.vectorized(4): + B[ + bx * 128 + + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) + // 128 + % 128, + by * 128 + + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) + % 128, + ] = A_shared_dyn[ + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) + % 128, + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) + // 128 + % 128, + ] + + +@tvm.script.ir_module +class TransformedGlobalToSharedWithLocalStage: + @T.prim_func + def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (1024, 1024)) + B = T.match_buffer(b, (1024, 1024)) + with T.block("root"): + T.reads(A[0:1024, 0:1024]) + T.writes(B[0:1024, 0:1024]) + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(""): + T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.writes(B[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + A_shared_dyn = T.alloc_buffer( + (128, 128), strides=(128, 1), scope="shared.dyn" + ) + with T.block("A_shared"): + T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.writes(A_shared_dyn[0:128, 0:128]) + T.block_attr( + {"auto_copy": 1, "local_stage": True, "vector_bytes": 16} + ) + A_shared_dyn_local = T.alloc_buffer((16, 4), scope="local") + for ax0_ax1_fused_1 in T.thread_binding(8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding( + 32, thread="threadIdx.x" + ): + for ax0_ax1_fused_0_cache in range(16): + for ax0_ax1_fused_3_cache in T.vectorized(4): + A_shared_dyn_local[ + ax0_ax1_fused_0_cache + * 8 + * 32 + * 4 + // 128 + % 128 + // 8, + ax0_ax1_fused_3_cache % 128, + ] = A[ + bx * 128 + + ( + ( + ( + ax0_ax1_fused_0_cache * 8 + + ax0_ax1_fused_1 + ) + * 32 + + ax0_ax1_fused_2 + ) + * 4 + + ax0_ax1_fused_3_cache + ) + // 128 + % 128, + by * 128 + + ( + ( + ( + ax0_ax1_fused_0_cache * 8 + + ax0_ax1_fused_1 + ) + * 32 + + ax0_ax1_fused_2 + ) + * 4 + + ax0_ax1_fused_3_cache + ) + % 128, + ] + for ax0_ax1_fused_0 in range(16): + for ax0_ax1_fused_3 in T.vectorized(4): + A_shared_dyn[ + ( + ( + (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) + * 32 + + ax0_ax1_fused_2 + ) + * 4 + + ax0_ax1_fused_3 + ) + // 128 + % 128, + ( + ( + (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) + * 32 + + ax0_ax1_fused_2 + ) + * 4 + + ax0_ax1_fused_3 + ) + % 128, + ] = A_shared_dyn_local[ + ax0_ax1_fused_0 * 8 * 32 * 4 // 128 % 128 // 8, + ax0_ax1_fused_3 % 128, + ] + with T.block("B"): + T.reads(A_shared_dyn[0:128, 0:128]) + T.writes(B[bx * 
128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + for ax0 in range(128): + for ax1 in range(128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class TransformedSharedToWmma: + @T.prim_func + def main() -> None: + s0 = T.int32() + s1 = T.int32() + # body + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer( + [128, 128], dtype="float16", strides=[136, 1], scope="shared.dyn" + ) + A_wmma = T.alloc_buffer( + [128, 128], dtype="float16", scope="wmma.matrix_a" + ) + with T.block("C_shared"): + T.reads(A_shared_dyn[0:128, 0:128]) + T.writes(A_wmma[0:128, 0:128]) + T.block_attr({"auto_copy": 1}) + for ax00, ax10 in T.grid(8, 8): + with T.block("wmma_load"): + T.reads( + A_shared_dyn[ + ax00 * 16 : ax00 * 16 + 16, + ax10 * 16 : ax10 * 16 + 16, + ] + ) + T.writes( + A_wmma[ + ax00 * 16 : ax00 * 16 + 16, + ax10 * 16 : ax10 * 16 + 16, + ] + ) + src = T.match_buffer( + A_shared_dyn[ + ax00 * 16 : ax00 * 16 + 16, + ax10 * 16 : ax10 * 16 + 16, + ], + [16, 16], + dtype="float16", + strides=[s1, s0], + scope="shared.dyn", + offset_factor=16, + ) + tgt = T.match_buffer( + A_wmma[ + ax00 * 16 : ax00 * 16 + 16, + ax10 * 16 : ax10 * 16 + 16, + ], + [16, 16], + dtype="float16", + scope="wmma.matrix_a", + offset_factor=16, + ) + T.evaluate( + T.tvm_load_matrix_sync( + tgt.data, + 16, + 16, + 16, + tgt.elem_offset // 256 + + tgt.elem_offset % 256 // 16, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + src.data, + src.elem_offset, + s1 * 16, + 1, + dtype="handle", + ), + s1, + "row_major", + dtype="handle", + ) + ) + + +@tvm.script.ir_module +class TransformedWmmaToShared: + @T.prim_func + def main() -> None: + s0 = T.int32() + s1 = T.int32() + # body + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer( + [128, 128], dtype="float32", scope="wmma.accumulator" + ) + C_shared = T.alloc_buffer( + [128, 128], dtype="float32", strides=[136, 1], scope="shared.dyn" + ) + with T.block("A_wmma"): + T.reads(C_accum[0:128, 0:128]) + T.writes(C_shared[0:128, 0:128]) + T.block_attr({"auto_copy": 1}) + for ax00, ax10 in T.grid(8, 8): + with T.block("wmma_store"): + T.reads( + C_accum[ + ax00 * 16 : ax00 * 16 + 16, + ax10 * 16 : ax10 * 16 + 16, + ] + ) + T.writes( + C_shared[ + ax00 * 16 : ax00 * 16 + 16, + ax10 * 16 : ax10 * 16 + 16, + ] + ) + src = T.match_buffer( + C_accum[ + ax00 * 16 : ax00 * 16 + 16, + ax10 * 16 : ax10 * 16 + 16, + ], + [16, 16], + dtype="float32", + scope="wmma.accumulator", + offset_factor=16, + ) + tgt = T.match_buffer( + C_shared[ + ax00 * 16 : ax00 * 16 + 16, + ax10 * 16 : ax10 * 16 + 16, + ], + [16, 16], + dtype="float32", + strides=[s1, s0], + scope="shared.dyn", + offset_factor=16, + ) + T.evaluate( + T.tvm_store_matrix_sync( + src.data, + 16, + 16, + 16, + src.elem_offset // 256 + + src.elem_offset % 256 // 16, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + tgt.data, + tgt.elem_offset, + s1 * 16, + 2, + dtype="handle", + ), + s1, + "row_major", + dtype="handle", + ) + ) + + +@tvm.script.ir_module +class TransformedWmmaToGlobal: + @T.prim_func + def main(C: T.Buffer((1024, 1024), "float32")): 
+        with T.block("root"):
+            T.reads()
+            T.writes(C[0:1024, 0:1024])
+            T.block_attr({"warp_execution": True})
+            for bx in T.thread_binding(8, thread="blockIdx.x"):
+                for by in T.thread_binding(8, thread="blockIdx.y"):
+                    for ty in T.thread_binding(8, thread="threadIdx.y"):
+                        with T.block(""):
+                            T.reads()
+                            T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128])
+                            C_accum = T.alloc_buffer((128, 128), scope="wmma.accumulator")
+                            with T.block("C_global"):
+                                T.reads(C_accum[0:128, 0:128])
+                                T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128])
+                                T.block_attr({"auto_copy": 1, "vector_bytes": 16})
+                                C_accum_shared_dyn = T.alloc_buffer(
+                                    (8, 8, 16, 16), strides=(2048, 256, 16, 1), scope="shared.dyn"
+                                )
+                                for ax0_0 in range(8):
+                                    for ax1_0 in range(8):
+                                        with T.block("wmma_store"):
+                                            T.reads(C_accum[ax0_0 * 16 : ax0_0 * 16 + 16, ax1_0 * 16 : ax1_0 * 16 + 16])
+                                            T.writes(C_accum_shared_dyn[ty, ax1_0, 0:16, 0:16])
+                                            src = T.match_buffer(
+                                                C_accum[ax0_0 * 16 : ax0_0 * 16 + 16, ax1_0 * 16 : ax1_0 * 16 + 16],
+                                                (16, 16),
+                                                scope="wmma.accumulator",
+                                                offset_factor=16,
+                                            )
+                                            s1 = T.int32()
+                                            s0 = T.int32()
+                                            tgt = T.match_buffer(
+                                                C_accum_shared_dyn[ty, ax1_0, 0:16, 0:16],
+                                                (16, 16),
+                                                strides=(s1, s0),
+                                                scope="shared.dyn",
+                                                offset_factor=16,
+                                            )
+                                            T.tvm_store_matrix_sync(
+                                                src.data,
+                                                16,
+                                                16,
+                                                16,
+                                                src.elem_offset // 256 + src.elem_offset % 256 // 16,
+                                                T.tvm_access_ptr(
+                                                    T.type_annotation("float32"), tgt.data, tgt.elem_offset, s1 * 16, 2
+                                                ),
+                                                s1,
+                                                "row_major",
+                                            )
+                                    for ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 in range(16):
+                                        for ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
+                                            for ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
+                                                for ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3 in T.vectorized(4):
+                                                    C[
+                                                        bx * 128 + (ax0_0 * 16 + (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) // 16 % 16),
+                                                        by * 128 + ((((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) // 16 // 16 % 8 * 16 + (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) % 16),
+                                                    ] = C_accum_shared_dyn[
+                                                        (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) // 16 // 16 // 8 % 8,
+                                                        (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) // 16 // 16 % 8,
+                                                        (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) // 16 % 16,
+                                                        (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) % 16,
+                                                    ]
+
+
+@tvm.script.ir_module
+class TransformedWmmaToGlobalWithFusion:
+    @T.prim_func
+    def main(A: T.Buffer((1024,), "float32"), C: T.Buffer((1024, 1024), "float32")) -> None:
+        s0 = T.int32()
+        s1 = T.int32()
+        # body
+        with T.block("root"):
+            T.reads(A[0:1024])
+            T.writes(C[0:1024, 0:1024])
+            T.block_attr({"warp_execution": True})
+            for bx in T.thread_binding(8, thread="blockIdx.x"):
+                for by in T.thread_binding(8, thread="blockIdx.y"):
+                    for ty in T.thread_binding(8, thread="threadIdx.y"):
+                        with T.block():
+                            T.reads(A[bx * 128 : bx * 128 + 128])
+                            T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128])
+                            C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator")
+                            with T.block("C_global"):
+                                T.reads(C_accum[0:128, 0:128], A[bx * 128 : bx * 128 + 128])
+                                T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128])
+                                T.block_attr({"auto_copy": 1, "vector_bytes": 16})
+                                C_accum_shared_dyn = T.alloc_buffer(
+                                    (8, 8, 16, 16), strides=(2048, 256, 16, 1), scope="shared.dyn"
+                                )
+                                for ax0_0 in range(8):
+                                    for ax1_0 in range(8):
+                                        with T.block("wmma_store"):
+                                            T.reads(C_accum[ax0_0 * 16 : ax0_0 * 16 + 16, ax1_0 * 16 : ax1_0 * 16 + 16])
+                                            T.writes(C_accum_shared_dyn[ty, ax1_0, 0:16, 0:16])
+                                            src = T.match_buffer(
+                                                C_accum[ax0_0 * 16 : ax0_0 * 16 + 16, ax1_0 * 16 : ax1_0 * 16 + 16],
+                                                (16, 16),
+                                                scope="wmma.accumulator",
+                                                offset_factor=16,
+                                            )
+                                            s1 = T.int32()
+                                            s0 = T.int32()
+                                            tgt = T.match_buffer(
+                                                C_accum_shared_dyn[ty, ax1_0, 0:16, 0:16],
+                                                (16, 16),
+                                                strides=(s1, s0),
+                                                scope="shared.dyn",
+                                                offset_factor=16,
+                                            )
+                                            T.tvm_store_matrix_sync(
+                                                src.data,
+                                                16,
+                                                16,
+                                                16,
+                                                src.elem_offset // 256 + src.elem_offset % 256 // 16,
+                                                T.tvm_access_ptr(
+                                                    T.type_annotation("float32"), tgt.data, tgt.elem_offset, s1 * 16, 2
+                                                ),
+                                                s1,
+                                                "row_major",
+                                            )
+                                    for ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 in range(16):
+                                        for ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
+                                            for ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
+                                                for ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3 in T.vectorized(4):
+                                                    C[
+                                                        bx * 128 + (ax0_0 * 16 + (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) // 16 % 16),
+                                                        by * 128 + ((((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) // 16 // 16 % 8 * 16 + (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) % 16),
+                                                    ] = (
+                                                        C_accum_shared_dyn[
+                                                            (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) // 16 // 16 // 8 % 8,
+                                                            (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) // 16 // 16 % 8,
+                                                            (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) // 16 % 16,
+                                                            (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) % 16,
+                                                        ]
+                                                        + A[bx * 128 + (ax0_0 * 16 + (((ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_0 * 8 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_1) * 32 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_2) * 4 + ty_cache_ax1_0_cache_ax0_1_cache_ax1_1_cache_fused_3) // 16 % 16)]
+                                                    )
+
+
+def _check(original, transformed):
+    """Apply LowerAutoCopy and compare the result against the expected module."""
+    mod = tvm.tir.transform.LowerAutoCopy()(original)
+    tvm.ir.assert_structural_equal(mod, transformed, map_free_vars=True)
+
+
+def test_coalesce_vectorize():
+    _check(GlobalToShared, TransformedGlobalToShared)
+
+
+def test_inverse():
+    _check(SharedToGlobal, TransformedSharedToGlobal)
+
+
+def test_local_stage():
+    _check(GlobalToSharedWithLocalStage, TransformedGlobalToSharedWithLocalStage)
+
+
+def test_rewrite_shared_to_wmma():
+    _check(SharedToWmma, TransformedSharedToWmma)
+
+
+def test_rewrite_wmma_to_shared():
+    _check(WmmaToShared, TransformedWmmaToShared)
+
+
+def test_rewrite_wmma_to_global():
+    _check(WmmaToGlobal, TransformedWmmaToGlobal)
+
+
+def verify_single_allocation(stmt, alloc_size=None):
+    """Assert that a single "shared.dyn" allocation is present, optionally checking its flattened size."""
+    num_alloc = [0]
+    alloc_extents = []
+
+    def verify(n):
+        if (
+            isinstance(n, tvm.tir.Block)
+            and n.alloc_buffers is not None
+            and any(buf.scope() == "shared.dyn" for buf in n.alloc_buffers)
+        ):
+            num_alloc[0] += len(n.alloc_buffers)
+            for buf in n.alloc_buffers:
+                alloc_extents.append(buf.shape)
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
+    assert num_alloc[0] == 1
+
+    if alloc_size:
+
+        def prod(arr):
+            ret = 1
+            for element in arr:
+                ret *= element
+            return ret
+
+        assert prod(alloc_extents[0]) == alloc_size
+
+
+def test_auto_padding():
+    mod = tvm.tir.transform.LowerAutoCopy()(Transpose)
+    mod = tvm.tir.transform.FlattenBuffer()(mod)
+    verify_single_allocation(mod["main"].body, 16 * 130)
+
+
+def test_rewrite_wmma_to_global_fusion():
+    _check(WmmaToGlobalWithFusion, TransformedWmmaToGlobalWithFusion)
+
+
+if __name__ == "__main__":
+    test_coalesce_vectorize()
+    test_inverse()
+    test_local_stage()
+    test_rewrite_shared_to_wmma()
+    test_rewrite_wmma_to_shared()
+    test_rewrite_wmma_to_global()
+    test_auto_padding()
+    test_rewrite_wmma_to_global_fusion()
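+
+
+# Rough usage sketch (commented out; not part of the test suite): the same
+# rewrite that `_check` verifies structurally can be inspected by hand by
+# running the pass on one of the annotated input modules defined earlier in
+# this file, e.g. `GlobalToShared`, and printing the lowered TVMScript:
+#
+#   lowered = tvm.tir.transform.LowerAutoCopy()(GlobalToShared)
+#   print(lowered.script())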