From 267982fc912a784b0c12b07f89686ca1643c889b Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Thu, 21 Oct 2021 22:14:05 +0800 Subject: [PATCH 1/7] reuse shared dyn --- src/target/source/codegen_cuda.cc | 2 +- ...merge_dynamic_shared_memory_allocations.cc | 494 +++++++++++++++++- ...merge_dynamic_shared_memory_allocations.py | 28 +- 3 files changed, 494 insertions(+), 30 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 0aad18ffb6f9..a52564c34a68 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -525,7 +525,7 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) { const std::string& sync = op->args[0].as()->value; if (sync == "warp") { // DO nothing. - } else if (sync == "shared") { + } else if (sync == "shared" || sync == "shared.dyn") { this->PrintIndent(); this->stream << "__syncthreads();\n"; } else if (sync == "global") { diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index e8865b260dc1..fa036bb76a7b 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -27,6 +27,8 @@ #include #include +#include +#include #include #include @@ -36,46 +38,250 @@ namespace tvm { namespace tir { +using runtime::StorageRank; +using runtime::StorageScope; + bool IsDynamicSharedMemory(Var buffer_var) { - auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var)); + StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var)); return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn"; } +/*! + * \brief collect the mapping from the buffer var to its allocate + */ class AllocateCollector : public StmtExprVisitor { public: void VisitStmt_(const AllocateNode* op) final { if (IsDynamicSharedMemory(op->buffer_var)) { - dyn_shmem_allocs_.insert(op); + dyn_shmem_allocs_[op->buffer_var.get()] = op; } StmtExprVisitor::VisitStmt_(op); } + // The mapping from the original buffer var to its allocate + std::unordered_map dyn_shmem_allocs_; +}; + +// Find a linear pattern of storage access +// Used for liveness analysis. +// Composite scopes(loop/thread_launch/IfThen) is represented by two points: +// before_scope -> scope_body -> after_scope +// +// The linear_seq_ stores before_scope and after_scope. +// The access to the arrays are stored at the after_scope point. +// +// Define "scope" as the body of For/thread_launch/IfThenElse +// This pass tries to detect last point that we need to keep memory +// alive under the same scope as allocate. +// The storage need to be kept alive between allocate and last access. +// The free point is only inserted at the same scope of allocate. +// +class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { + public: + /*! \brief record the touch hist of statment. */ + struct StmtEntry { + // The statment + const Object* stmt; + // The index in the linear_seq_ to point to end of the nested scope. + // This is only set to non-zero if stmt is a nested scope. + // if offset > 0, means this is the begin, the end entry is current_index + offset + // if offset < 0, means this is the end, the begin entry is current_index + offset + int64_t scope_pair_offset{0}; + // The buffer variables this statment touched. 
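
The two offsets are mirror images: for a scope whose before/after entries land at linear_seq_ indices b and a, the begin entry stores a - b and the end entry stores b - a, so either endpoint recovers its partner as current_index + scope_pair_offset. A minimal sketch of that lookup, assuming the StmtEntry layout above (the helper name is illustrative, not part of the pass):

size_t PartnerIndex(const std::vector<StmtEntry>& seq, size_t i) {
  // offset > 0 at a scope begin, offset < 0 at a scope end, 0 for leaf stmts
  return static_cast<size_t>(static_cast<int64_t>(i) + seq[i].scope_pair_offset);
}
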
+ std::vector touched; + }; + // The scope of each allocation + struct AllocEntry { + // scope level + size_t level{0}; + // allocation stmt + const AllocateNode* alloc{nullptr}; + }; - std::unordered_set dyn_shmem_allocs_; + void VisitStmt_(const AllocateNode* op) final { + size_t level = scope_.size(); + const VarNode* buf = op->buffer_var.get(); + alloc_info_[buf].alloc = op; + alloc_info_[buf].level = level; + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const StoreNode* op) final { + scope_.push_back(StmtEntry()); + // visit subexpr + StmtExprVisitor::VisitStmt_(op); + // Add write access. + const VarNode* buf = op->buffer_var.get(); + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()); + if (IsDynamicSharedMemory(GetRef(buf))) { + scope_[it->second.level].touched.push_back(buf); + } + } + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (e.touched.size() != 0) { + e.stmt = op; + linear_seq_.push_back(e); + } + } + void VisitStmt_(const EvaluateNode* op) final { + scope_.push_back(StmtEntry()); + // visit subexpr + StmtExprVisitor::VisitStmt_(op); + StmtEntry e = scope_.back(); + scope_.pop_back(); + if (e.touched.size() != 0) { + e.stmt = op; + linear_seq_.push_back(e); + } + } + void VisitExpr_(const LoadNode* op) final { + // Add write access. + StmtExprVisitor::VisitExpr_(op); + const VarNode* buf = op->buffer_var.get(); + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; + if (IsDynamicSharedMemory(GetRef(buf))) { + scope_[it->second.level].touched.push_back(buf); + } + } + } + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::address_of())) { + const LoadNode* l = op->args[0].as(); + this->VisitExpr(l->index); + } else { + StmtExprVisitor::VisitExpr_(op); + } + } + void VisitExpr_(const VarNode* buf) final { + // Directly reference to the variable count as a read. + auto it = alloc_info_.find(buf); + if (it != alloc_info_.end() && it->second.alloc) { + ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint; + if (IsDynamicSharedMemory(GetRef(buf))) { + scope_[it->second.level].touched.push_back(buf); + } + } + } + template + void VisitNewScope(const T* op) { + scope_.push_back(StmtEntry()); + StmtEntry e; + e.stmt = op; + int64_t begin_index = static_cast(linear_seq_.size()); + // before scope. + linear_seq_.push_back(e); + StmtExprVisitor::VisitStmt_(op); + // after scope. + e.touched = std::move(scope_.back().touched); + scope_.pop_back(); + int64_t end_index = static_cast(linear_seq_.size()); + ICHECK_GT(end_index, begin_index); + e.scope_pair_offset = begin_index - end_index; + linear_seq_.push_back(e); + // record the pointer to end index. + ICHECK_NE(end_index, 0U); + linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; + } + void VisitStmt_(const AttrStmtNode* op) final { + // Only record the outer most thread extent. 
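
A lowered kernel launch typically carries several nested thread_extent attributes (schematic TIR, with illustrative extents):

attr [blockIdx.x] "thread_extent" = 64 {
  attr [threadIdx.x] "thread_extent" = 128 {
    ... kernel body, dynamic shared buffers allocated here ...
  }
}

Only the outermost level opens a new scope; in_thread_env_ suppresses the inner ones, so the whole launch forms a single liveness region for the dynamic shared buffers.
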
+ if (op->attr_key == attr::thread_extent && !in_thread_env_) { + in_thread_env_ = true; + VisitNewScope(op); + in_thread_env_ = false; + } else if (op->attr_key == attr::extern_scope) { + VisitNewScope(op); + } else if (op->attr_key == attr::virtual_thread) { + VisitNewScope(op); + } else { + StmtExprVisitor::VisitStmt_(op); + } + } + void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); } + + void VisitStmt_(const ForNode* op) final { VisitNewScope(op); } + + void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); } + + void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } + + // linearized access sequence. + std::vector linear_seq_; + // The storage scope of each buffer + std::unordered_map alloc_info_; + + private: + // Whether already in thread env. + bool in_thread_env_{false}; + // The scope stack. + std::vector scope_; }; +/*! + * \brief merge the buffers whose live range has no intersection and rewrite the body + */ class DynamicSharedMemoryRewriter : public StmtExprMutator { public: explicit DynamicSharedMemoryRewriter( - const std::unordered_set& dyn_shmem_allocs) + const std::unordered_map& dyn_shmem_allocs) : dyn_shmem_allocs_{dyn_shmem_allocs} {} + /*! + * \brief plan the memory reuse for all the buffer allocated in the statement + * @param stmt the statement + */ + void PlanReuse(Stmt stmt) { + DynSharedMemLinearAccessPatternFinder finder; + finder(stmt); + this->LivenessAnalysis(finder.linear_seq_); + this->PlanMemory(finder.linear_seq_); + } + Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent && !allocated) { + if (op->attr_key == attr::thread_extent && !allocated_) { // Allocate one dynamic shared memory allocation at the beginning of thread scope - int align = 1; - for (const auto& alloc : dyn_shmem_allocs_) { - ICHECK_EQ(alloc->dtype.lanes(), 1) << "vector dtype allocation not supported."; - align = std::max(align, alloc->dtype.bytes()); + int max_layer_num = 0; + std::vector all_entry; + for (const auto& e : const_free_map_) { + all_entry.push_back(e.second); + } + for (const StorageEntry* e : sym_free_list_) { + all_entry.push_back(e); + } + for (const StorageEntry* e : all_entry) { + max_layer_num = std::max(max_layer_num, static_cast(e->allocs.size())); + } + // calculate align for each layer of each storage entry. 
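
A worked example using the sizes from the StorageEntry diagram below, with dtypes added here for illustration (A: 4096 bytes of float16 and B: 4096 bytes of float32 in layer 0; C: 8192 bytes of float32 in layer 1): align[0] = max(2, 4) = 4 and align[1] = 4, so A is placed at byte 0, B at byte 4096 (already a multiple of 4, so no padding is inserted), and C overlaps both starting at byte 0; the entry contributes max(8192, 8192) = 8192 bytes to the merged buffer.
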
+ std::vector align(max_layer_num, 0); + for (const StorageEntry* e : all_entry) { + for (int i = 0; i < static_cast(e->allocs.size()); i++) { + for (const VarNode* buffer : e->allocs[i]) { + const AllocateNode* alloc = dyn_shmem_allocs_[buffer]; + align[i] = std::max(align[i], alloc->dtype.bytes()); + } + } } - for (const auto& alloc : dyn_shmem_allocs_) { - ICHECK_EQ(alloc->extents.size(), 1); - buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; - merged_alloc_size_ += alloc->extents[0] * align; + // calculate offset for each buffer based on the align of each layer + for (const StorageEntry* e : all_entry) { + PrimExpr max_inner_offset = 0; + for (int i = 0; i < static_cast(e->allocs.size()); i++) { + PrimExpr inner_offset = 0; + for (const VarNode* buffer : e->allocs[i]) { + const AllocateNode* alloc = dyn_shmem_allocs_[buffer]; + buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset; + inner_offset += alloc->extents[0] * alloc->dtype.bytes(); + inner_offset += indexmod(align[i] - indexmod(inner_offset, align[i]), align[i]); + } + max_inner_offset = max(max_inner_offset, inner_offset); + } + merged_alloc_size_ += max_inner_offset; } - allocated = true; - auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, - const_true(), StmtExprMutator::VisitStmt(op->body)); + allocated_ = true; + Allocate new_body(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, const_true(), + StmtExprMutator::VisitStmt(op->body)); return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span); } return StmtMutator::VisitStmt_(op); @@ -90,8 +296,8 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const LoadNode* op) final { if (IsDynamicSharedMemory(op->buffer_var)) { - auto offset = GetBufferOffset(op->buffer_var, op->dtype); - auto index = StmtExprMutator::VisitExpr(op->index); + PrimExpr offset = GetBufferOffset(op->buffer_var, op->dtype); + PrimExpr index = StmtExprMutator::VisitExpr(op->index); return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span); } return StmtExprMutator::VisitExpr_(op); @@ -99,33 +305,271 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { Stmt VisitStmt_(const StoreNode* op) final { if (IsDynamicSharedMemory(op->buffer_var)) { - auto offset = GetBufferOffset(op->buffer_var, op->value->dtype); - auto index = StmtExprMutator::VisitExpr(op->index); - auto value = StmtExprMutator::VisitExpr(op->value); + PrimExpr offset = GetBufferOffset(op->buffer_var, op->value->dtype); + PrimExpr index = StmtExprMutator::VisitExpr(op->index); + PrimExpr value = StmtExprMutator::VisitExpr(op->value); return Store(merged_buf_var_, value, offset + index, op->predicate, op->span); } return StmtExprMutator::VisitStmt_(op); } + PrimExpr VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(op->args.size(), 5U); + DataType dtype = op->args[0].dtype(); + Var buffer = Downcast(op->args[1]); + if (!IsDynamicSharedMemory(buffer)) { + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr extra_offset = GetBufferOffset(buffer, dtype); + + PrimExpr offset = this->VisitExpr(op->args[2]); + PrimExpr extent = this->VisitExpr(op->args[3]); + return Call(op->dtype, op->op, + {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}); + } else { + return StmtExprMutator::VisitExpr_(op); + } + } + private: PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { auto it = buffer_byte_offsets_.find(buffer_var.get()); 
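
buffer_byte_offsets_ stores offsets in bytes, while Load/Store indices count elements, hence the indexdiv by dtype.bytes() just below: a buffer placed at byte 4096 becomes element offset 2048 under a float16 view and 1024 under a float32 view. The per-layer alignment computed above is what makes this division come out exact.
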
- ICHECK(it != buffer_byte_offsets_.end()); + ICHECK(it != buffer_byte_offsets_.end()) << buffer_var; return indexdiv(it->second, dtype.bytes()); } + using StmtEntry = DynSharedMemLinearAccessPatternFinder::StmtEntry; + + + struct StorageEntry { + // The constant size of the buffer in bits, only used if it is constant + uint64_t const_nbits{0}; + // Allocs that shares this entry. + // The inner vector means a "layer" + // For example, it we need to allocate C in the memory of A and B: + // | A: 4096 bytes | B: 4096 bytes | + // | C: 8192 bytes | + // Then the allocs = {{A, B}, {C}} + std::vector> allocs; + }; + + // Event entry in liveness analysis + struct EventEntry { + // variables we generate + std::vector gen; + // variables we kill + std::vector kill; + }; + + /*! + * \brief Liveness analysis to find gen and kill point of each variable. + * \param seq the linear pattern of storage access + */ + void LivenessAnalysis(const std::vector& seq) { + // find kill point, do a reverse linear scan. + std::unordered_set touched; + for (size_t i = seq.size(); i != 0; --i) { + const StmtEntry& s = seq[i - 1]; + for (const VarNode* buffer : s.touched) { + if (!touched.count(buffer)) { + touched.insert(buffer); + event_map_[s.stmt].kill.push_back(buffer); + } + } + } + // find gen point, do forward scan + touched.clear(); + for (size_t i = 0; i < seq.size(); ++i) { + int64_t offset = seq[i].scope_pair_offset; + if (offset < 0) continue; + const StmtEntry& s = seq[i + offset]; + for (const VarNode* buffer : s.touched) { + if (!touched.count(buffer)) { + touched.insert(buffer); + event_map_[s.stmt].gen.push_back(buffer); + } + } + } + } + + /*! + * \brief Memory plan algorithm + * \param seq the linear pattern of storage access + * \param alloc_info + */ + void PlanMemory(const std::vector& seq) { + std::unordered_set inplace_flag; + + for (size_t i = 0; i < seq.size(); ++i) { + auto it = event_map_.find(seq[i].stmt); + // scope_pair_offset >= 0 means it is either + // - leaf stmt(offset = 0) + // - beginning of scope(offset < 0) + // In both cases, we need to handle the gen event correctly + if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) { + for (const VarNode* var : it->second.gen) { + ICHECK(dyn_shmem_allocs_.count(var)); + const AllocateNode* alloc = dyn_shmem_allocs_[var]; + StorageEntry* dst_entry = FindAlloc(alloc); + alloc_map_[var] = dst_entry; + } + } + // scope_pair_offset <= 0 means it is either + // - leaf stmt(offset = 0) + // - end of scope(offset < 0) + // In both cases, we need to handle the kill event correctly + if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { + for (const VarNode* var : it->second.kill) { + this->Free(var); + } + } + } + } + /*! + * \brief Allocate new storage entry. + * \param op the allocate node + * \param the size of the allocation in bits + * \return the new storage entry + */ + StorageEntry* NewAlloc(const AllocateNode* op, size_t const_nbits) { + ICHECK(op != nullptr); + // Re-use not successful, allocate a new buffer. + std::unique_ptr entry(new StorageEntry()); + entry->allocs.push_back({op->buffer_var.get()}); + entry->const_nbits = const_nbits; + StorageEntry* e = entry.get(); + alloc_vec_.emplace_back(std::move(entry)); + return e; + } + /*! 
+ * \brief find the storage entry in the free list for the allocate + * \param op the allocate node + * \return the storage entry + */ + StorageEntry* FindAlloc(const AllocateNode* op) { + ICHECK(op != nullptr); + // skip plan for local variable, + // compiler can do a better job with register allocation. + const uint64_t match_range = 16; + uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes(); + uint64_t const_nbits = static_cast(op->constant_allocation_size() * op_elem_bits); + // disable reuse of small arrays, they will be lowered to registers in LLVM + // This rules only apply if we are using non special memory + if (const_nbits > 0 && const_nbits <= 32) { + return NewAlloc(op, const_nbits); + } + + if (const_nbits != 0) { + // constant allocation. + auto begin = const_free_map_.lower_bound(0); + auto mid = const_free_map_.lower_bound(const_nbits); + auto end = const_free_map_.upper_bound(const_nbits * match_range); + // Start looking at the buffer that is bigger than the required size first. + // If we find one, directly allocate the buffer in its location and remove its entry in the + // free list + for (auto it = mid; it != end; ++it) { + StorageEntry* e = it->second; + e->const_nbits = std::max(const_nbits, e->const_nbits); + const_free_map_.erase(it); + return e; + } + // Then start looking at smaller buffers. + // Keep collecting the buffer until the sum of their size exceeds the buffer to allocate + // and finally free all these entry in the free list + std::vector::iterator> delete_it; + // the alloc list for the new entry + std::vector> reuse_allocs; + uint64_t mem_ct = 0; + for (auto it = mid; it != begin;) { + --it; + if (mem_ct + it->second->const_nbits <= const_nbits) { + delete_it.push_back(it); + mem_ct += it->second->const_nbits; + int n = it->second->allocs.size(); + if (n > static_cast(reuse_allocs.size())) { + reuse_allocs.resize(n, {}); + } + for (int i = 0; i < n; i++) { + for (const VarNode* alloc : it->second->allocs[i]) { + reuse_allocs[i].push_back(alloc); + } + } + } + } + reuse_allocs.push_back({op->buffer_var.get()}); + if (mem_ct != 0) { + std::unique_ptr entry(new StorageEntry()); + entry->const_nbits = std::max(const_nbits, mem_ct); + entry->allocs = reuse_allocs; + for (auto it : delete_it) { + const_free_map_.erase(it); + } + StorageEntry* e = entry.get(); + alloc_vec_.emplace_back(std::move(entry)); + return e; + } + } else { + // if its symbolic allocation, just arbitrarily choose one entry to fit in because we don't + // know its actual size + for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) { + StorageEntry* e = *it; + sym_free_list_.erase(it); + return e; + } + } + return NewAlloc(op, const_nbits); + } + + /*! + * \brief add the storage entry to the buffer var into the free list. + * \param var the buffer var + */ + void Free(const VarNode* var) { + auto it = alloc_map_.find(var); + ICHECK(it != alloc_map_.end()); + StorageEntry* e = it->second; + ICHECK_NE(e->allocs.size(), 0U); + + // disable reuse of small arrays + if (e->const_nbits > 0 && e->const_nbits <= 32) return; + + // normal free. 
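
This is why the pass keeps two free structures: const_free_map_ is keyed by size so FindAlloc can run its best-fit scan with lower_bound/upper_bound (either a single larger block in the window [const_nbits, 16 * const_nbits], with match_range = 16 as defined above, or a coalesced set of smaller blocks), while symbolically sized entries go to sym_free_list_, from which FindAlloc simply takes the first entry because symbolic extents cannot be ordered at compile time.
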
+ if (e->const_nbits != 0) { + const_free_map_.insert({e->const_nbits, e}); + } else { + sym_free_list_.push_back(e); + } + } + // The var for the merged buffer Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; - std::unordered_set dyn_shmem_allocs_; + // The mapping from the original buffer var to its allocate + std::unordered_map dyn_shmem_allocs_; + // The size of the merged buffer PrimExpr merged_alloc_size_{0}; + // The mapping from the original buffer var to its offset in the merged buffer std::unordered_map buffer_byte_offsets_; - bool allocated{false}; + // The flag indicating whether the merged buffer has been allocated + bool allocated_{false}; + // Locations of free ops. + std::unordered_map event_map_; + // constant size free map. + std::multimap const_free_map_; + // symbolic free list, for non constant items. + std::list sym_free_list_; + // The allocation assign map + std::unordered_map alloc_map_; + // The allocations + std::vector> alloc_vec_; }; Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) { AllocateCollector collector; collector(stmt); if (collector.dyn_shmem_allocs_.size() > 1) { - return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_)(std::move(stmt)); + DynamicSharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_); + rewriter.PlanReuse(stmt); + return rewriter(std::move(stmt)); } return stmt; } diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index cc78b84f9b4e..2cbd3acb8c9b 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -82,13 +82,14 @@ def test_matmul_ir(A, B, C): # Create a dynamic shared memory for the accumulation. # This is for testing merging dynamic shared memory alloctions with different data type. # In practice, there is no need to allocate a shared memory for C. 
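
Accumulating into C_local (a per-thread register) and writing C_sh only after the reduction loop shrinks C_sh's live range so that it no longer overlaps A_sh and B_sh; the smaller expected_alloc_size asserted further down depends on exactly this reuse.
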
+ C_local = ib.allocate(C.dtype, (1), scope="local", name="C_local") C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh") # fp32 A_ptr = ib.buffer_ptr(A) B_ptr = ib.buffer_ptr(B) C_ptr = ib.buffer_ptr(C) - C_sh[ty, tx] = 0.0 + C_local[0] = 0.0 with ib.for_range(0, n // block, name="i") as i: A_sh[ty, tx] = A_ptr[by * block + ty, i * block + tx] @@ -96,10 +97,10 @@ def test_matmul_ir(A, B, C): ib.emit(syncthread()) with ib.for_range(0, block, name="k") as k: - C_sh[ty, tx] += cast(A_sh[ty, k] * B_sh[k, tx], "float32") - + C_local[0] += cast(A_sh[ty, k] * B_sh[k, tx], "float32") ib.emit(syncthread()) + C_sh[ty, tx] = C_local[0] C_ptr[by * block + ty, bx * block + tx] = C_sh[ty, tx] return ib.get() @@ -113,7 +114,8 @@ def test_matmul_ir(A, B, C): ) s = te.create_schedule(C.op) mod = run_passes(s, [A, B, C]) - expected_alloc_size = block * block * 3 * 4 + # C can be allocated at the start of A, so we only need to allocate 2 block * block memory with dtype = float16 + expected_alloc_size = block * block * 4 verify_single_allocation(mod["main"].body, expected_alloc_size) def check_target(target): @@ -248,6 +250,24 @@ def test_device_ir(A, B, C, D): # merged allocation # allocate(buf_dyn_shmem: Pointer(shared.dyn uint8), uint8, [((n_dyn*4) + 256)]); verify_single_allocation(mod["main"].body) + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fadd = tvm.build(s, [A, B, C, D], target) + dev = tvm.device(target, 0) + + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.nd.array(np.random.uniform(size=n).astype(C.dtype), dev) + d = tvm.nd.array(np.zeros((n,), dtype=D.dtype), dev) + fadd(a, b, c, d) + tvm.testing.assert_allclose( + d.numpy(), a.numpy() + b.numpy() +c.numpy(), 1e-4, 1e-4 + ) + + for target in ["cuda", "nvptx"]: + check_target(target) if __name__ == "__main__": From e64e4b1699692fd062241f98236026394bb08840 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Thu, 21 Oct 2021 22:22:02 +0800 Subject: [PATCH 2/7] minor --- src/driver/driver_api.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 24cae798988e..3ddf45e0bb8a 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -590,7 +590,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(BindTarget(target)); mixed_pass_list.push_back(tir::transform::VerifyMemory()); - mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); if (ShouldAnnotateEntryFunc(target, mixed_mod)) { mixed_pass_list.push_back(AnnotateEntryFunc(true)); @@ -603,6 +602,8 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); + mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn")); + mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); mixed_pass_list.push_back(tir::transform::InferFragment()); mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); From 430ae5d971ac7990e008a7e9d0ebbb3f81851265 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Thu, 21 Oct 2021 22:34:11 +0800 Subject: [PATCH 3/7] format --- ..._tir_transform_merge_dynamic_shared_memory_allocations.py | 5 ++--- 1 file changed, 2 insertions(+), 3 
deletions(-) diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 2cbd3acb8c9b..5c5a17e9181c 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -250,6 +250,7 @@ def test_device_ir(A, B, C, D): # merged allocation # allocate(buf_dyn_shmem: Pointer(shared.dyn uint8), uint8, [((n_dyn*4) + 256)]); verify_single_allocation(mod["main"].body) + def check_target(target): if not tvm.testing.device_enabled(target): return @@ -262,9 +263,7 @@ def check_target(target): c = tvm.nd.array(np.random.uniform(size=n).astype(C.dtype), dev) d = tvm.nd.array(np.zeros((n,), dtype=D.dtype), dev) fadd(a, b, c, d) - tvm.testing.assert_allclose( - d.numpy(), a.numpy() + b.numpy() +c.numpy(), 1e-4, 1e-4 - ) + tvm.testing.assert_allclose(d.numpy(), a.numpy() + b.numpy() + c.numpy(), 1e-4, 1e-4) for target in ["cuda", "nvptx"]: check_target(target) From 073970b65451f1d67f81f04bd3e87d794493af01 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Thu, 21 Oct 2021 22:43:20 +0800 Subject: [PATCH 4/7] format --- src/tir/transforms/merge_dynamic_shared_memory_allocations.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index fa036bb76a7b..4fd93e0c57f6 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -340,8 +340,6 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { } using StmtEntry = DynSharedMemLinearAccessPatternFinder::StmtEntry; - - struct StorageEntry { // The constant size of the buffer in bits, only used if it is constant uint64_t const_nbits{0}; From 96a9974217b1b183d90b95cfb9a25806f6a78ce2 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Mon, 25 Oct 2021 21:59:52 +0800 Subject: [PATCH 5/7] address comment and fix --- ...merge_dynamic_shared_memory_allocations.cc | 72 +++++++++---------- ...merge_dynamic_shared_memory_allocations.py | 59 +++++++++++++++ 2 files changed, 93 insertions(+), 38 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 4fd93e0c57f6..b685dc523f7a 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -27,12 +27,11 @@ #include #include -#include -#include #include #include #include "../../runtime/thread_storage_scope.h" +#include "../../support/arena.h" #include "ir_utils.h" namespace tvm { @@ -91,7 +90,7 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { }; // The scope of each allocation struct AllocEntry { - // scope level + // the level in the scope stack size_t level{0}; // allocation stmt const AllocateNode* alloc{nullptr}; @@ -159,7 +158,7 @@ class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { // Directly reference to the variable count as a read. 
auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { - ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint; + ICHECK_LT(it->second.level, scope_.size()); if (IsDynamicSharedMemory(GetRef(buf))) { scope_[it->second.level].touched.push_back(buf); } @@ -230,9 +229,9 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { /*! * \brief plan the memory reuse for all the buffer allocated in the statement - * @param stmt the statement + * \param stmt the statement */ - void PlanReuse(Stmt stmt) { + void PlanReuse(const Stmt& stmt) { DynSharedMemLinearAccessPatternFinder finder; finder(stmt); this->LivenessAnalysis(finder.linear_seq_); @@ -335,7 +334,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { private: PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { auto it = buffer_byte_offsets_.find(buffer_var.get()); - ICHECK(it != buffer_byte_offsets_.end()) << buffer_var; + ICHECK(it != buffer_byte_offsets_.end()); return indexdiv(it->second, dtype.bytes()); } @@ -401,6 +400,15 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { for (size_t i = 0; i < seq.size(); ++i) { auto it = event_map_.find(seq[i].stmt); + // scope_pair_offset <= 0 means it is either + // - leaf stmt(offset = 0) + // - end of scope(offset < 0) + // In both cases, we need to handle the kill event correctly + if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { + for (const VarNode* var : it->second.kill) { + this->Free(var); + } + } // scope_pair_offset >= 0 means it is either // - leaf stmt(offset = 0) // - beginning of scope(offset < 0) @@ -413,15 +421,6 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { alloc_map_[var] = dst_entry; } } - // scope_pair_offset <= 0 means it is either - // - leaf stmt(offset = 0) - // - end of scope(offset < 0) - // In both cases, we need to handle the kill event correctly - if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { - for (const VarNode* var : it->second.kill) { - this->Free(var); - } - } } } /*! @@ -433,12 +432,10 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { StorageEntry* NewAlloc(const AllocateNode* op, size_t const_nbits) { ICHECK(op != nullptr); // Re-use not successful, allocate a new buffer. - std::unique_ptr entry(new StorageEntry()); + StorageEntry* entry = arena_.make(); entry->allocs.push_back({op->buffer_var.get()}); entry->const_nbits = const_nbits; - StorageEntry* e = entry.get(); - alloc_vec_.emplace_back(std::move(entry)); - return e; + return entry; } /*! 
* \brief find the storage entry in the free list for the allocate @@ -481,30 +478,29 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { uint64_t mem_ct = 0; for (auto it = mid; it != begin;) { --it; - if (mem_ct + it->second->const_nbits <= const_nbits) { - delete_it.push_back(it); - mem_ct += it->second->const_nbits; - int n = it->second->allocs.size(); - if (n > static_cast(reuse_allocs.size())) { - reuse_allocs.resize(n, {}); - } - for (int i = 0; i < n; i++) { - for (const VarNode* alloc : it->second->allocs[i]) { - reuse_allocs[i].push_back(alloc); - } + delete_it.push_back(it); + mem_ct += it->second->const_nbits; + int n = it->second->allocs.size(); + if (n > static_cast(reuse_allocs.size())) { + reuse_allocs.resize(n, {}); + } + for (int i = 0; i < n; i++) { + for (const VarNode* alloc : it->second->allocs[i]) { + reuse_allocs[i].push_back(alloc); } } + if (mem_ct >= const_nbits) { + break; + } } reuse_allocs.push_back({op->buffer_var.get()}); if (mem_ct != 0) { - std::unique_ptr entry(new StorageEntry()); - entry->const_nbits = std::max(const_nbits, mem_ct); - entry->allocs = reuse_allocs; + StorageEntry* e = arena_.make(); + e->const_nbits = std::max(const_nbits, mem_ct); + e->allocs = reuse_allocs; for (auto it : delete_it) { const_free_map_.erase(it); } - StorageEntry* e = entry.get(); - alloc_vec_.emplace_back(std::move(entry)); return e; } } else { @@ -557,8 +553,8 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { std::list sym_free_list_; // The allocation assign map std::unordered_map alloc_map_; - // The allocations - std::vector> alloc_vec_; + /*! \brief allocator of all the StorageEntry*/ + support::Arena arena_; }; Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) { diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 5c5a17e9181c..9d87b4be9724 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -269,7 +269,66 @@ def check_target(target): check_target(target) +def test_dyn_shared_more_dtype(): + """Test vectorized store into dynamic shared memory""" + n = 512 + A = te.placeholder((n,), name="A", dtype="int8") + B = te.placeholder((n,), name="B", dtype="int16") + + def test_device_ir(A, B, C): + n = A.shape[0] + ib = tvm.tir.ir_builder.create() + + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", n) + + A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn") # i8 + B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn") # i16 + C_sh = ib.allocate(C.dtype, (n,), scope="shared.dyn") # i32 + + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + A_sh[tx] = Aptr[tx] + B_sh[tx] = Bptr[tx] + + C_sh[tx] = cast(A_sh[tx], "int32") + cast(B_sh[tx], "int32") + Cptr[tx] = C_sh[tx] + return ib.get() + + C = te.extern( + (n,), + [A, B], + lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]), + name="vadd", + dtype="int32", + ) + s = te.create_schedule(C.op) + + mod = run_passes(s, [A, B, C]) + print(mod["main"].body) + verify_single_allocation(mod["main"].body, n * 4) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fadd = tvm.build(s, [A, B, C], target) + dev = tvm.device(target, 0) + + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + b = 
tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev) + fadd(a, b, c) + tvm.testing.assert_allclose(c.numpy(), a.numpy().astype("float32") + b.numpy(), 1e-4, 1e-4) + + for target in ["cuda", "nvptx"]: + check_target(target) + + if __name__ == "__main__": test_matmul_dyn_shared() test_dyn_shared_vectorized_store() test_dyn_shared_reuse_and_merge() + test_dyn_shared_more_dtype() From 07464d0d537f7e8a08231c74ea180489364a0943 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Thu, 28 Oct 2021 12:14:53 +0800 Subject: [PATCH 6/7] address comment --- ...merge_dynamic_shared_memory_allocations.cc | 23 +++++++++---------- ...merge_dynamic_shared_memory_allocations.py | 1 - 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index b685dc523f7a..7310e7431325 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -62,30 +62,29 @@ class AllocateCollector : public StmtExprVisitor { // Find a linear pattern of storage access // Used for liveness analysis. -// Composite scopes(loop/thread_launch/IfThen) is represented by two points: -// before_scope -> scope_body -> after_scope -// -// The linear_seq_ stores before_scope and after_scope. -// The access to the arrays are stored at the after_scope point. +// "linear" means fitting a complex access pattern into an array of StmtEntry // // Define "scope" as the body of For/thread_launch/IfThenElse +// Composite scopes(loop/thread_launch/IfThen) is represented by three StmtEntry: +// before_scope -> scope_body -> after_scope +// // This pass tries to detect last point that we need to keep memory -// alive under the same scope as allocate. -// The storage need to be kept alive between allocate and last access. -// The free point is only inserted at the same scope of allocate. +// alive under the same scope as Allocate. +// The storage need to be kept alive between Allocate and last access. +// The free point is only inserted at the same scope of Allocate. // class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { public: - /*! \brief record the touch hist of statment. */ + /*! \brief record the touch list of statment. */ struct StmtEntry { - // The statment + // The statement const Object* stmt; // The index in the linear_seq_ to point to end of the nested scope. // This is only set to non-zero if stmt is a nested scope. // if offset > 0, means this is the begin, the end entry is current_index + offset // if offset < 0, means this is the end, the begin entry is current_index + offset int64_t scope_pair_offset{0}; - // The buffer variables this statment touched. + // The buffer variables this statement touched. 
std::vector touched; }; // The scope of each allocation @@ -237,6 +236,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { this->LivenessAnalysis(finder.linear_seq_); this->PlanMemory(finder.linear_seq_); } + private: Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent && !allocated_) { @@ -331,7 +331,6 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { } } - private: PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { auto it = buffer_byte_offsets_.find(buffer_var.get()); ICHECK(it != buffer_byte_offsets_.end()); diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 9d87b4be9724..fffcb40ce95d 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -307,7 +307,6 @@ def test_device_ir(A, B, C): s = te.create_schedule(C.op) mod = run_passes(s, [A, B, C]) - print(mod["main"].body) verify_single_allocation(mod["main"].body, n * 4) def check_target(target): From eed8d5e25e149e97390bff8cdaa3ad536d59f9e1 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sun, 31 Oct 2021 01:12:56 +0800 Subject: [PATCH 7/7] address comment --- src/tir/transforms/merge_dynamic_shared_memory_allocations.cc | 4 ++-- ...t_tir_transform_merge_dynamic_shared_memory_allocations.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 7310e7431325..f3ff1f37a5da 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -75,7 +75,7 @@ class AllocateCollector : public StmtExprVisitor { // class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor { public: - /*! \brief record the touch list of statment. */ + /*! \brief record the touch list of statement. */ struct StmtEntry { // The statement const Object* stmt; @@ -236,8 +236,8 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { this->LivenessAnalysis(finder.linear_seq_); this->PlanMemory(finder.linear_seq_); } - private: + private: Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent && !allocated_) { // Allocate one dynamic shared memory allocation at the beginning of thread scope diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index fffcb40ce95d..46d39c034454 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -82,7 +82,7 @@ def test_matmul_ir(A, B, C): # Create a dynamic shared memory for the accumulation. # This is for testing merging dynamic shared memory alloctions with different data type. # In practice, there is no need to allocate a shared memory for C. - C_local = ib.allocate(C.dtype, (1), scope="local", name="C_local") + C_local = ib.allocate(C.dtype, (1,), scope="local", name="C_local") C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh") # fp32 A_ptr = ib.buffer_ptr(A)
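
Taken together, the series replaces the per-buffer "shared.dyn" allocations with one merged buf_dyn_shmem allocation whose layout honors both liveness-based reuse (overlapping layers) and per-layer alignment. Below is a self-contained sketch of that layout computation, reduced to constant sizes; the names, dtypes, and numbers are illustrative, and the real pass operates on TIR nodes and also supports symbolic extents:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// One buffer: its total size in bytes and the element size that dictates
// its alignment (extent * dtype_bytes == bytes in the real pass).
struct Buf {
  std::string name;
  int64_t bytes;
  int64_t dtype_bytes;
};

// One merged StorageEntry: the inner vectors are "layers" whose members may
// share the same bytes because their live ranges never overlap.
using Entry = std::vector<std::vector<Buf>>;

// Mirrors the offset pass in DynamicSharedMemoryRewriter: per-layer alignment
// is the largest element size seen at that layer depth across all entries;
// buffers within a layer are concatenated with padding, layers of one entry
// overlap, and entries are stacked one after another.
int64_t PlanLayout(const std::vector<Entry>& entries,
                   std::map<std::string, int64_t>* offsets) {
  size_t max_layers = 0;
  for (const Entry& e : entries) max_layers = std::max(max_layers, e.size());

  std::vector<int64_t> align(max_layers, 1);
  for (const Entry& e : entries)
    for (size_t i = 0; i < e.size(); ++i)
      for (const Buf& b : e[i]) align[i] = std::max(align[i], b.dtype_bytes);

  int64_t merged = 0;
  for (const Entry& e : entries) {
    int64_t max_inner = 0;
    for (size_t i = 0; i < e.size(); ++i) {
      int64_t inner = 0;
      for (const Buf& b : e[i]) {
        (*offsets)[b.name] = merged + inner;
        inner += b.bytes;
        inner += (align[i] - inner % align[i]) % align[i];  // pad to layer alignment
      }
      max_inner = std::max(max_inner, inner);
    }
    merged += max_inner;  // layers overlapped; entries are sequential
  }
  return merged;
}

int main() {
  // The example from the StorageEntry comment: C reuses the bytes of A and B.
  std::vector<Entry> entries = {
      {
          {{"A", 4096, 2}, {"B", 4096, 4}},  // layer 0: A (fp16), then B (fp32)
          {{"C", 8192, 4}},                  // layer 1: C overlaps A and B
      },
  };
  std::map<std::string, int64_t> offsets;
  int64_t total = PlanLayout(entries, &offsets);
  for (const auto& kv : offsets)
    std::cout << kv.first << " @ byte " << kv.second << "\n";
  std::cout << "merged size = " << total << " bytes\n";  // prints 8192
}

Buffers within one layer are packed back to back with padding up to the layer's alignment, while all layers of an entry start at the same base byte: the liveness analysis has already proven that members of different layers are never live at the same time, so letting them overlap is safe.
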