From 617d7e00fbad9f4b7695c7391fd79e0d1bc61da4 Mon Sep 17 00:00:00 2001 From: masa Date: Sat, 24 Jul 2021 20:24:44 +0900 Subject: [PATCH 01/21] add shared mem matmul test --- tests/python/unittest/test_tir_ir_builder.py | 80 ++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 0329134bb3fa..2e9f398c58fb 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -554,6 +554,85 @@ def check_target(target): check_target(target) +@tvm.testing.requires_gpu +def test_matmul_dyn_shared(): + n = 1024 + A = te.placeholder((n, n), name="A", dtype="float16") + B = te.placeholder((n, n), name="B", dtype="float16") + + def syncthread(): + return tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])) + + def test_matmul_ir(A, B, C): + ib = tvm.tir.ir_builder.create() + block = 16 + + tx = te.thread_axis("threadIdx.x") + ty = te.thread_axis("threadIdx.y") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", block) + ib.scope_attr(ty, "thread_extent", block) + ib.scope_attr(bx, "thread_extent", n / block) + ib.scope_attr(by, "thread_extent", n / block) + + A_sh = ib.allocate(A.dtype, (block, block), scope="shared") # fp16 + B_sh = ib.allocate(B.dtype, (block, block), scope="shared") # fp16 + # Create a dynamic shared memory for the accumulation. + # This is for testing merging dynamic shared memory alloctions with different data type. + # In practice, there is no need to allocate a shared memory for C. 
+ C_sh = ib.allocate(C.dtype, (block, block), scope="shared") # fp32 + + A_ptr = ib.buffer_ptr(A) + B_ptr = ib.buffer_ptr(B) + C_ptr = ib.buffer_ptr(C) + + C_sh[ty, tx] = 0.0 + + with ib.for_range(0, n // block, name="i") as i: + A_sh[ty, tx] = A_ptr[by * block + ty, i * block + tx] + B_sh[ty, tx] = B_ptr[i * block + ty, bx * block + tx] + ib.emit(syncthread()) + + with ib.for_range(0, block, name="k") as k: + C_sh[ty, tx] += cast(A_sh[ty, k] * B_sh[k, tx], "float32") + + ib.emit(syncthread()) + + C_ptr[by * block + ty, bx * block + tx] = C_sh[ty, tx] + + return ib.get() + + C = te.extern( + A.shape, + [A, B], + lambda ins, outs: test_matmul_ir(ins[0], ins[1], outs[0]), + name="reduce", + dtype="float32", + ) + s = te.create_schedule(C.op) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fmatmul = tvm.build(s, [A, B, C], target) + dev = tvm.device(target, 0) + + size = (n, n) + a_np = np.random.uniform(size=size).astype(A.dtype) + b_np = np.random.uniform(size=size).astype(B.dtype) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(size, dtype=C.dtype), dev) + fmatmul(a, b, c) + np_ref = np.dot(a_np.astype("float32"), b_np.astype("float32")) + tvm.testing.assert_allclose(c.numpy(), np_ref, 1e-4, 1e-4) + + for target in ["cuda", "nvptx"]: + check_target(target) + + if __name__ == "__main__": test_prefetch() test_if() @@ -565,3 +644,4 @@ def check_target(target): test_while_mandel() test_while_binary_search() test_dyn_shared() + test_matmul_dyn_shared() From 26bfb17ec4f6f4e21086227ecc9c124a1896f2a5 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 27 Jul 2021 04:22:18 +0900 Subject: [PATCH 02/21] Add a stub pass --- include/tvm/tir/transform.h | 5 + python/tvm/driver/build_module.py | 5 +- python/tvm/tir/transform/transform.py | 11 +++ src/driver/driver_api.cc | 1 + ...merge_dynamic_shared_memory_allocations.cc | 92 +++++++++++++++++++ tests/python/unittest/test_tir_ir_builder.py | 6 +- 6 files 
changed, 116 insertions(+), 4 deletions(-) create mode 100644 src/tir/transforms/merge_dynamic_shared_memory_allocations.cc diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 5ee847e2f010..2065201ff29b 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -431,6 +431,11 @@ TVM_DLL Pass LegalizePackedCalls(); */ TVM_DLL Pass FlattenBuffer(); +/*! + * A pass to merge multiple TIR-level dynamic shared memory allocations into one + */ +TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 983a40ab5b3f..a7ebc00c315f 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -161,7 +161,10 @@ def _build_for_device(input_mod, target, target_host): mod_mixed = input_mod mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) - opt_mixed = [tvm.tir.transform.VerifyMemory()] + opt_mixed = [ + tvm.tir.transform.VerifyMemory(), + tvm.tir.transform.MergeDynamicSharedMemoryAllocations(), + ] if len(mod_mixed.functions) == 1: opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))] diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 1e5c303fa17c..42efdf2e9ba9 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -655,3 +655,14 @@ def FlattenBuffer(): The result pass """ return _ffi_api.FlattenBuffer() # type: ignore + + +def MergeDynamicSharedMemoryAllocations(): + """TODO + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.MergeDynamicSharedMemoryAllocations() # type: ignore diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 1591e875a4b3..7f6d704b544a 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -377,6 +377,7 @@ 
std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target Array mixed_pass_list = {BindTarget(target), tir::transform::VerifyMemory()}; + mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations()); if (pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value()) { mixed_pass_list.push_back(tir::transform::ThreadSync("global")); } diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc new file mode 100644 index 000000000000..c2a675840411 --- /dev/null +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file merge_dynamic_shared_memory_allocations.cc + */ +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +class AllocateCollector : public StmtExprVisitor { + public: + void VisitStmt_(const AllocateNode* op) final { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + dyn_shmem_allocs_.insert(op); + } + StmtExprVisitor::VisitStmt_(op); + } + + std::unordered_set dyn_shmem_allocs_; +}; + +class DynamicSharedMemoryRewriter : public StmtExprMutator { + public: + DynamicSharedMemoryRewriter(const std::unordered_set& dyn_shmem_allocs) + : dyn_shmem_allocs_{dyn_shmem_allocs} {} + + Stmt Rewrite(Stmt stmt) { return stmt; } + + PrimExpr VisitExpr_(const LoadNode* op) final { return StmtExprMutator::VisitExpr_(op); } + + Stmt VisitStmt_(const AllocateNode* op) final { return StmtExprMutator::VisitStmt_(op); } + + Stmt VisitStmt_(const StoreNode* op) final { return StmtExprMutator::VisitStmt_(op); } + + private: + Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; + std::unordered_set dyn_shmem_allocs_; +}; + +Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) { + AllocateCollector collector; + collector(stmt); + return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_).Rewrite(std::move(stmt)); +} + +namespace transform { + +Pass MergeDynamicSharedMemoryAllocations() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.MergeDynamicSharedMemoryAllocations") + 
.set_body_typed(MergeDynamicSharedMemoryAllocations); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 2e9f398c58fb..1b396316c798 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -576,12 +576,12 @@ def test_matmul_ir(A, B, C): ib.scope_attr(bx, "thread_extent", n / block) ib.scope_attr(by, "thread_extent", n / block) - A_sh = ib.allocate(A.dtype, (block, block), scope="shared") # fp16 - B_sh = ib.allocate(B.dtype, (block, block), scope="shared") # fp16 + A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn") # fp16 + B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn") # fp16 # Create a dynamic shared memory for the accumulation. # This is for testing merging dynamic shared memory alloctions with different data type. # In practice, there is no need to allocate a shared memory for C. - C_sh = ib.allocate(C.dtype, (block, block), scope="shared") # fp32 + C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn") # fp32 A_ptr = ib.buffer_ptr(A) B_ptr = ib.buffer_ptr(B) From 9eab8d345bdb7b9745fb5479ac94d45cdc75c0e3 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 27 Jul 2021 05:01:54 +0900 Subject: [PATCH 03/21] add builtin for reinterprete load/store --- include/tvm/tir/builtin.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 61280d33f1df..8dce7bcec780 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -601,6 +601,18 @@ TVM_DLL const Op& vectorcombine(); */ TVM_DLL const Op& atomic_add(); +/*! + * \brief TODO + * reinterprete_load(ptr, "float32", idx) is equivalent to ((float*)ptr)[idx] + */ +TVM_DLL const Op& reinterprete_load(); + +/*! 
+ * \brief TODO + * reinterprete_store(ptr, "float32", 0.0) is equivalent to (*(float*)ptr) = 0.0 + */ +TVM_DLL const Op& reinterprete_store(); + /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address From 27c881a359540ac44242aa2aa3820e5102692ad3 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 27 Jul 2021 05:33:35 +0900 Subject: [PATCH 04/21] remove builtin since Load/Store node already supports reinterpret --- include/tvm/tir/builtin.h | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 8dce7bcec780..61280d33f1df 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -601,18 +601,6 @@ TVM_DLL const Op& vectorcombine(); */ TVM_DLL const Op& atomic_add(); -/*! - * \brief TODO - * reinterprete_load(ptr, "float32", idx) is equivalent to ((float*)ptr)[idx] - */ -TVM_DLL const Op& reinterprete_load(); - -/*! - * \brief TODO - * reinterprete_store(ptr, "float32", 0.0) is equivalent to (*(float*)ptr) = 0.0 - */ -TVM_DLL const Op& reinterprete_store(); - -/*!
\brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address From f1e35fa3771fd58c2a27df09c73252eb34520d5d Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 27 Jul 2021 05:40:45 +0900 Subject: [PATCH 05/21] Add Load/Store visitor implementation --- ...merge_dynamic_shared_memory_allocations.cc | 44 ++++++++++++++++--- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index c2a675840411..e1a09ec19734 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -36,11 +36,15 @@ namespace tvm { namespace tir { +bool IsDynamicSharedMemory(Var buffer_var) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var)); + return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn"; +} + class AllocateCollector : public StmtExprVisitor { public: void VisitStmt_(const AllocateNode* op) final { - auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + if (IsDynamicSharedMemory(op->buffer_var)) { dyn_shmem_allocs_.insert(op); } StmtExprVisitor::VisitStmt_(op); @@ -54,17 +58,45 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { DynamicSharedMemoryRewriter(const std::unordered_set& dyn_shmem_allocs) : dyn_shmem_allocs_{dyn_shmem_allocs} {} - Stmt Rewrite(Stmt stmt) { return stmt; } + Stmt Rewrite(Stmt stmt) { + return Allocate(merged_buf_var_, merged_buf_var_->dtype, {merged_alloc_size_}, true, + StmtExprMutator::VisitStmt(stmt)); + } - PrimExpr VisitExpr_(const LoadNode* op) final { return StmtExprMutator::VisitExpr_(op); } + Stmt VisitStmt_(const AllocateNode* op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { 
+ return StmtExprMutator::VisitStmt(op->body); + } + return StmtExprMutator::VisitStmt_(op); + } - Stmt VisitStmt_(const AllocateNode* op) final { return StmtExprMutator::VisitStmt_(op); } + PrimExpr VisitExpr_(const LoadNode* op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + auto offset = GetBufferOffset(op->buffer_var, op->dtype); + return Load(op->dtype, merged_buf_var_, offset + op->index, op->predicate, op->span); + } + return StmtExprMutator::VisitExpr_(op); + } - Stmt VisitStmt_(const StoreNode* op) final { return StmtExprMutator::VisitStmt_(op); } + Stmt VisitStmt_(const StoreNode* op) final { + if (IsDynamicSharedMemory(op->buffer_var)) { + auto offset = GetBufferOffset(op->buffer_var, op->value->dtype); + return Store(merged_buf_var_, op->value, offset + op->index, op->predicate, op->span); + } + return StmtExprMutator::VisitStmt_(op); + } private: + PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { + auto it = buffer_offsets_.find(buffer_var.get()); + ICHECK(it != buffer_offsets_.end()); + return indexdiv(it->second, dtype.bytes()); + } + Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; std::unordered_set dyn_shmem_allocs_; + PrimExpr merged_alloc_size_; + std::unordered_map buffer_offsets_; }; Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) { From 482ecd8aa51569683342457fcf935b1cf2aa77f1 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 27 Jul 2021 22:35:41 +0900 Subject: [PATCH 06/21] allocate merge first cut --- ...merge_dynamic_shared_memory_allocations.cc | 24 ++++++++++++--- tests/python/unittest/test_tir_ir_builder.py | 30 +++++++++---------- 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index e1a09ec19734..25bf730a7739 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ 
b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -59,7 +59,18 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { : dyn_shmem_allocs_{dyn_shmem_allocs} {} Stmt Rewrite(Stmt stmt) { - return Allocate(merged_buf_var_, merged_buf_var_->dtype, {merged_alloc_size_}, true, + int align = 1; + for (auto& alloc : dyn_shmem_allocs_) { + align = std::max(align, alloc->dtype.bytes()); + } + for (auto& alloc : dyn_shmem_allocs_) { + buffer_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; + merged_alloc_size_ += alloc->extents[0] * align; + LOG(INFO) << "buffer offset for " << alloc->buffer_var->name_hint << " = " + << buffer_offsets_[alloc->buffer_var.get()]; + } + + return Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, const_true(), StmtExprMutator::VisitStmt(stmt)); } @@ -73,7 +84,8 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { PrimExpr VisitExpr_(const LoadNode* op) final { if (IsDynamicSharedMemory(op->buffer_var)) { auto offset = GetBufferOffset(op->buffer_var, op->dtype); - return Load(op->dtype, merged_buf_var_, offset + op->index, op->predicate, op->span); + auto index = StmtExprMutator::VisitExpr(op->index); + return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span); } return StmtExprMutator::VisitExpr_(op); } @@ -81,7 +93,9 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { Stmt VisitStmt_(const StoreNode* op) final { if (IsDynamicSharedMemory(op->buffer_var)) { auto offset = GetBufferOffset(op->buffer_var, op->value->dtype); - return Store(merged_buf_var_, op->value, offset + op->index, op->predicate, op->span); + auto index = StmtExprMutator::VisitExpr(op->index); + auto value = StmtExprMutator::VisitExpr(op->value); + return Store(merged_buf_var_, value, offset + index, op->predicate, op->span); } return StmtExprMutator::VisitStmt_(op); } @@ -95,7 +109,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { Var merged_buf_var_{"buf_dyn_shmem", 
PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; std::unordered_set dyn_shmem_allocs_; - PrimExpr merged_alloc_size_; + PrimExpr merged_alloc_size_{0}; std::unordered_map buffer_offsets_; }; @@ -110,7 +124,9 @@ namespace transform { Pass MergeDynamicSharedMemoryAllocations() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); + LOG(INFO) << "Before: " << f; n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body)); + LOG(INFO) << "After: " << f; return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {}); diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 1b396316c798..e148d4844a3e 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -557,8 +557,8 @@ def check_target(target): @tvm.testing.requires_gpu def test_matmul_dyn_shared(): n = 1024 - A = te.placeholder((n, n), name="A", dtype="float16") - B = te.placeholder((n, n), name="B", dtype="float16") + A = te.placeholder((n, n), name="A", dtype="float32") + B = te.placeholder((n, n), name="B", dtype="float32") def syncthread(): return tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])) @@ -576,12 +576,12 @@ def test_matmul_ir(A, B, C): ib.scope_attr(bx, "thread_extent", n / block) ib.scope_attr(by, "thread_extent", n / block) - A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn") # fp16 - B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn") # fp16 + A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn", name="A_sh") # fp16 + B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn", name="B_sh") # fp16 # Create a dynamic shared memory for the accumulation. # This is for testing merging dynamic shared memory alloctions with different data type. # In practice, there is no need to allocate a shared memory for C. 
- C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn") # fp32 + C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh") # fp32 A_ptr = ib.buffer_ptr(A) B_ptr = ib.buffer_ptr(B) @@ -634,14 +634,14 @@ def check_target(target): if __name__ == "__main__": - test_prefetch() - test_if() - test_for() - test_cpu() - test_gpu() - test_while_vectorize() - test_while_collatz() - test_while_mandel() - test_while_binary_search() - test_dyn_shared() + # test_prefetch() + # test_if() + # test_for() + # test_cpu() + # test_gpu() + # test_while_vectorize() + # test_while_collatz() + # test_while_mandel() + # test_while_binary_search() + # test_dyn_shared() test_matmul_dyn_shared() From ce62d9e53ea467bd33d1485006aba97a02719712 Mon Sep 17 00:00:00 2001 From: masa Date: Thu, 8 Jul 2021 07:43:57 +0900 Subject: [PATCH 07/21] Remove all attr::storage_scope usage --- include/tvm/tir/stmt.h | 2 - python/tvm/contrib/hexagon.py | 16 ++++--- python/tvm/script/scope_handler.py | 2 +- python/tvm/tir/ir_builder.py | 2 - src/printer/tvmscript_printer.cc | 48 ++++++++----------- src/relay/backend/aot_executor_codegen.cc | 2 - src/target/source/codegen_c.cc | 15 ++---- src/te/operation/cross_thread_reduction.cc | 3 -- src/tir/analysis/verify_gpu_code.cc | 21 +++----- src/tir/ir/stmt.cc | 10 ---- src/tir/transforms/flatten_buffer.cc | 1 - src/tir/transforms/inject_copy_intrin.cc | 15 ++---- src/tir/transforms/inject_double_buffer.cc | 17 ++----- src/tir/transforms/ir_utils.cc | 10 ---- .../lower_device_storage_access_info.cc | 46 ++++-------------- src/tir/transforms/lower_thread_allreduce.cc | 12 +---- src/tir/transforms/lower_warp_memory.cc | 17 +------ src/tir/transforms/storage_flatten.cc | 1 - src/tir/transforms/storage_rewrite.cc | 6 +-- .../transforms/tensorcore_infer_fragment.cc | 11 +---- .../update_pointer_storage_scope.cc | 11 ----- .../transforms/update_pointer_storage_scope.h | 1 - tests/python/unittest/test_tir_ir_builder.py | 2 - 
.../test_tir_transform_coproc_sync.py | 4 +- ...test_tir_transform_inject_double_buffer.py | 4 +- ...est_tir_transform_inject_virtual_thread.py | 12 ++--- .../test_tir_transform_lift_attr_scope.py | 4 +- .../test_tir_transform_loop_partition.py | 4 +- .../test_tir_transform_lower_warp_memory.py | 5 +- .../test_tir_transform_storage_flatten.py | 6 +-- .../test_tir_transform_storage_rewrite.py | 24 ++++------ .../test_tir_transform_thread_sync.py | 2 +- 32 files changed, 96 insertions(+), 240 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 9997a4d95694..c41cac2a3a25 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1240,8 +1240,6 @@ constexpr const char* extern_scope = "extern_scope"; * This can hint some code generator to create a new function for compute. */ constexpr const char* compute_scope = "compute_scope"; -/*! \brief Mark storage scope of buffers */ -constexpr const char* storage_scope = "storage_scope"; /*! \brief Mark storage alignement requirement of buffers */ constexpr const char* storage_alignment = "storage_alignment"; /*! 
\brief Mark storage scope of realization */ diff --git a/python/tvm/contrib/hexagon.py b/python/tvm/contrib/hexagon.py index 34b37537776f..c2197af22d2a 100644 --- a/python/tvm/contrib/hexagon.py +++ b/python/tvm/contrib/hexagon.py @@ -176,23 +176,27 @@ def buf_align(var): def visit(stmt): """Collect information about VTCM buffers and their alignments.""" if isinstance(stmt, tvm.tir.AttrStmt): - if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm": - vtcm_buffers.append(stmt.node) - elif stmt.attr_key == "storage_alignment": + if stmt.attr_key == "storage_alignment": if not stmt.node in alignments: alignments[stmt.node] = [] alignments[stmt.node].append(stmt.value) + elif isinstance(stmt, tvm.tir.Allocate): + scope = stmt.buffer_var.type_annotation.storage_scope + if scope == "local.vtcm": + vtcm_buffers.append(stmt.node) + def mutate(stmt): """Insert calls to VTCM allocation and deallocation routines.""" if isinstance(stmt, tvm.tir.AttrStmt): - if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm": - vtcm_buffers.pop() - elif stmt.attr_key == "storage_alignment": + if stmt.attr_key == "storage_alignment": alignments[stmt.node].pop() return stmt if isinstance(stmt, tvm.tir.Allocate): var = stmt.buffer_var + scope = var.type_annotation.storage_scope + if scope == "local.vtcm": + vtcm_buffers.pop() if var in vtcm_buffers: is_null = tvm.tir.call_intrin("bool", tvm.ir.Op.get("tir.isnullptr"), var) throw_error = tvm.tir.call_intrin( diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index d07209485bd4..971580343763 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -113,7 +113,7 @@ def allocate(extents, dtype, scope, condition=True, span=None): body = tvm.tir.Allocate( self.buffer_var, dtype, extents, condition, self.body, span=span ) - return tvm.tir.AttrStmt(self.buffer_var, "storage_scope", scope, body, span=span) + return body super().__init__(allocate, 
concise_scope=True, def_symbol=True) self.buffer_var = None diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 35932540fe68..978c630b17ad 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -419,8 +419,6 @@ def allocate(self, dtype, shape, name="buf", scope=""): buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] - if scope: - self.scope_attr(buffer_var, "storage_scope", scope) self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) return BufferVar(self, buffer_var, shape, dtype) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 01f79bd0c750..cc8aa48f3cd1 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -37,6 +37,7 @@ #include #include +#include "../tir/transforms/ir_utils.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" @@ -579,31 +580,6 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { Doc doc; - // merge attr with allocate when possible - if (op->node->IsInstance() && op->attr_key == "storage_scope" && - op->body->IsInstance()) { - const auto* alloc = Downcast(op->body).get(); - if (alloc->buffer_var.same_as(op->node)) { - var_not_in_headers.insert(alloc->buffer_var.get()); - if (current_num_ != num_child_ - 1) { - doc << "with tir.allocate(" << Print(alloc->extents) << ", " << PrintDType(alloc->dtype) - << ", " << Print(op->value); - if (!is_one(alloc->condition)) { - doc << ", " << Print(alloc->condition); - } - doc << ") as " << Print(op->node) << ":"; - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body)); - } else { - doc << Print(op->node) << " = tir.allocate(" << Print(alloc->extents) << ", " - << PrintDType(alloc->dtype) << ", " << Print(op->value); - if (!is_one(alloc->condition)) { - doc << ", " << 
Print(alloc->condition); - } - doc << ")" << Doc::NewLine() << PrintBody(alloc->body); - } - return doc; - } - } // merge attr with realize when possible if (op->node->IsInstance() && op->attr_key == "realize_scope" && op->body->IsInstance()) { @@ -681,8 +657,26 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { } Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { - LOG(FATAL) << "TVM Script Printer Internal Error: All the Allocate should be folded with Attr"; - return Doc(); + var_not_in_headers.insert(op->buffer_var.get()); + Doc doc; + auto storage_scope = GetPtrStorageScope(op->buffer_var); + if (current_num_ != num_child_ - 1) { + doc << "with tir.allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype) << ", " + << Print(storage_scope); + if (!is_one(op->condition)) { + doc << ", " << Print(op->condition); + } + doc << ") as " << Print(op->buffer_var) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } else { + doc << Print(op->buffer_var) << " = tir.allocate(" << Print(op->extents) << ", " + << PrintDType(op->dtype) << ", " << Print(storage_scope); + if (!is_one(op->condition)) { + doc << ", " << Print(op->condition); + } + doc << ")" << Doc::NewLine() << PrintBody(op->body); + } + return doc; } Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 4df38b9449ae..fd6ee27eb6be 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -625,8 +625,6 @@ class AOTExecutorCodegen : public ExprVisitor { // so we don't pay the price of allocation for every inference if (!allocated[sid]) { body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body); - body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"), - body); } allocated[sid] = true; } diff --git a/src/target/source/codegen_c.cc 
b/src/target/source/codegen_c.cc index 99c9452975d4..8397044e8b93 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -861,12 +861,11 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - const VarNode* buffer = op->buffer_var.as(); - auto it = alloc_storage_scope_.find(buffer); - if (it != alloc_storage_scope_.end()) { - std::string scope = alloc_storage_scope_.at(buffer); - PrintStorageScope(scope, stream); - } + + auto scope = GetPtrStorageScope(op->buffer_var); + alloc_storage_scope_[op->buffer_var.get()] = scope; + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); stream << ' ' << vid << '[' << constant_size << "];\n"; @@ -882,10 +881,6 @@ void CodeGenC::VisitStmt_(const AttrStmtNode* op) { BindThreadIndex(iv); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - alloc_storage_scope_[v] = op->value.as()->value; } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); ICHECK(v); diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index f844090ca6f5..2ed5fd4029a2 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -225,12 +225,9 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmt(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); if (!normal_red.empty()) { body = Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = - AttrStmt(normal_res_handles[idx - 1], 
tir::attr::storage_scope, StringImm("local"), body); } } body = Substitute(body, value_map); diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index afd3c7add605..10d857bdc953 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -30,6 +30,8 @@ #include #include +#include "../transforms/ir_utils.h" + namespace tvm { namespace tir { @@ -58,11 +60,12 @@ class GPUCodeVerifier : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) final { StmtVisitor::VisitStmt_(op); + auto scope = GetPtrStorageScope(op->buffer_var); // visit an allocation of a buffer in shared memory, record its size - if (visited_local_buffers_.count(op->buffer_var.get()) != 0) { + if (scope == "local") { size_t size = static_cast(op->constant_allocation_size()); local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); - } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) { + } else if (scope == "shared") { size_t size = static_cast(op->constant_allocation_size()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } @@ -78,15 +81,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - std::string op_value = op->value.as()->value; - if (op_value == "local") { - visited_local_buffers_.insert(op->node.as()); - } else if (op_value == "shared") { - visited_shared_buffers_.insert(op->node.as()); - } - StmtVisitor::VisitStmt_(op); - } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { if (nest_level_ == 0) { // enter a new kernel, reset statistics Reset_(); @@ -211,8 +206,6 @@ class GPUCodeVerifier : public StmtExprVisitor { private: int nest_level_{0}; - std::unordered_set visited_local_buffers_; - std::unordered_set visited_shared_buffers_; 
std::unordered_set visited_threads_; size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; @@ -230,8 +223,6 @@ class GPUCodeVerifier : public StmtExprVisitor { std::vector errors_; void Reset_() { - visited_local_buffers_.clear(); - visited_shared_buffers_.clear(); local_memory_per_block_ = 0; shared_memory_per_block_ = 0; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 42ef60bb86d7..6fdeb30ec100 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -61,16 +61,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { - if (attr_key == attr::storage_scope) { - const VarNode* buf = node.as(); - ICHECK(buf); - const auto* ptr_type = buf->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - auto attr_scope = value.as()->value; - ICHECK_EQ(attr_scope, ptr_type->storage_scope) - << "Storage scopes attached to AttrStmt and buffer var are different. 
" << attr_scope - << ", " << ptr_type->storage_scope; - } auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 88c254a8cb5e..f1f914fa2f5c 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -130,7 +130,6 @@ class BufferFlattener : public StmtExprMutator { String storage_scope = buffer.scope(); PrimExpr area = BufferArea(buffer); body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body)); - body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), std::move(body)); return body; } diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index 40f0e368d93d..74538bcb6806 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -29,6 +29,7 @@ #include #include "../../arith/pattern_match.h" +#include "ir_utils.h" namespace tvm { namespace tir { @@ -42,10 +43,7 @@ class CopyIntrinInjector : public StmtMutator { flower_copy_fromto_(flower_copy_fromto) {} Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = op->value.as()->value; - } else if (op->attr_key == pragma_key_) { + if (op->attr_key == pragma_key_) { Stmt ret; ICHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body; return ret; @@ -157,19 +155,12 @@ class CopyIntrinInjector : public StmtMutator { } // Get storage scope std::string GetStorageScope(const VarNode* var) const { - auto it = storage_scope_.find(var); - if (it != storage_scope_.end()) { - return it->second; - } else { - return ""; - } + return GetPtrStorageScope(GetRef(var)); } // pragma key std::string pragma_key_; // function to lower copy intrinsics. 
const PackedFunc& flower_copy_fromto_; - // Storage scope - std::unordered_map storage_scope_; // arith analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 7a16c06d8058..0b45bde28dfe 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -95,16 +95,7 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - auto it = dbuffer_info_.find(buf); - if (it != dbuffer_info_.end()) { - it->second.scope = op->value.as()->value; - return this->VisitStmt(op->body); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } else if (op->attr_key == attr::double_buffer_scope) { + if (op->attr_key == attr::double_buffer_scope) { return MakeProducer(op); } else { return StmtExprMutator::VisitStmt_(op); @@ -112,8 +103,10 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AllocateNode* op) final { - auto it = dbuffer_info_.find(op->buffer_var.get()); + const VarNode* buf = op->buffer_var.as(); + auto it = dbuffer_info_.find(buf); if (it != dbuffer_info_.end()) { + it->second.scope = GetPtrStorageScope(op->buffer_var); it->second.stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes(); @@ -125,8 +118,6 @@ class DoubleBufferInjector : public StmtExprMutator { } ICHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; - alloc_nest.emplace_back( - AttrStmt(op->buffer_var, attr::storage_scope, StringImm(it->second.scope), Evaluate(0))); alloc_nest.emplace_back( Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); return op->body; diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 
f7ece25d3fcd..b7348fe09fe2 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -172,16 +172,6 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { if (const VarNode* v = op->node.as()) { - if (op->attr_key == attr::storage_scope) { - const AllocateNode* alloc = op->body.as(); - if (alloc && op->node.same_as(alloc->buffer_var)) { - Stmt new_alloc = this->VisitStmt(op->body); - if (new_alloc.same_as(op->body)) return GetRef(op); - alloc = new_alloc.as(); - ICHECK(alloc); - return AttrStmt(alloc->buffer_var, op->attr_key, op->value, new_alloc); - } - } Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index eafed837cee3..0893f02d7443 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -44,13 +44,13 @@ class StorageAccessInfoLower : public StmtExprMutator { // Lower allocate to device allocate when needed. 
Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - // For special memory, remove allocate, or use head expr - auto it = storage_info_.find(op->buffer_var.get()); - if (it != storage_info_.end() && it->second.info.defined()) { - const MemoryInfo& info = it->second.info; - ++it->second.alloc_count; - ICHECK_LE(it->second.alloc_count, 1) - << "Double allocation of " << it->second.scope.to_string(); + auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end()) + << "Double allocation of " << scope.to_string(); + if (scope.tag.length() != 0) { + auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); + ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string(); + storage_info_[op->buffer_var.get()] = info; if (info->head_address.defined()) { return LetStmt(op->buffer_var, info->head_address, op->body); @@ -61,23 +61,6 @@ class StorageAccessInfoLower : public StmtExprMutator { return stmt; } } - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::Create(op->value.as()->value); - StorageEntry e; - e.scope = scope; - if (scope.tag.length() != 0 && scope.tag != ".dyn") { - e.info = GetMemoryInfo(op->value.as()->value); - ICHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); - } - storage_info_[buf] = e; - return StmtExprMutator::VisitStmt_(op); - - } else { - return StmtExprMutator::VisitStmt_(op); - } - } PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { @@ -99,8 +82,8 @@ class StorageAccessInfoLower : public StmtExprMutator { Var buffer_var = Downcast(op->args[1]); PrimExpr offset = op->args[2]; auto it = storage_info_.find(buffer); - if (it != storage_info_.end() && it->second.info.defined()) { - return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, 
it->second.info); + if (it != storage_info_.end() && it->second.defined()) { + return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second); } ICHECK(op->dtype.is_handle()); // Change to address_of @@ -118,17 +101,8 @@ class StorageAccessInfoLower : public StmtExprMutator { return cast(ptr_type, analyzer_.Simplify( offset / make_const(offset.dtype(), info->unit_bits / dtype_bits))); } - // The storage entry. - struct StorageEntry { - // Whether it is tagged memory. - StorageScope scope; - // The memory info if any. - MemoryInfo info; - // Allocation counter - int alloc_count{0}; - }; // The storage scope of each buffer - std::unordered_map storage_info_; + std::unordered_map storage_info_; // analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 25a2f4e060dd..481b1bfd4b19 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -53,8 +53,7 @@ class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScop // use volatile access to shared buffer. 
body = AttrStmt(remapped, attr::volatile_scope, 1, body); } - body = Allocate(remapped, op->dtype, op->extents, op->condition, body); - return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), body); + return Allocate(remapped, op->dtype, op->extents, op->condition, body); } return StmtExprMutator::VisitStmt_(op); } @@ -71,15 +70,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); thread_extents_.pop_back(); return ret; - } else if (op->attr_key == attr::storage_scope) { - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - const VarNode* v = op->node.as(); - if (alloc_remap_.count(v)) { - return op->body; - } else { - return ret; - } } else if (op->attr_key == attr::reduce_scope) { const CommReducerNode* combiner = op->node.as(); ICHECK(combiner); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 060b02c3d137..8cc6d3f2541f 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -364,28 +364,15 @@ class WarpMemoryRewriter : private StmtMutator { Stmt VisitStmt_(const AllocateNode* op) { auto ret = StmtMutator::VisitStmt_(op); op = ret.as(); - if (warp_buffer_.count(op->buffer_var.get())) { + if (GetPtrStorageScope(op->buffer_var) == "warp") { + new_storage_scopes_[op->buffer_var.get()] = "local"; WarpAccessRewriter rewriter(warp_size_, &analyzer_); ret = rewriter.Rewrite(op); } return ret; } - Stmt VisitStmt_(const AttrStmtNode* op) { - using runtime::StorageScope; - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::Create(op->value.as()->value); - if (scope.rank == runtime::StorageRank::kWarp) { - warp_buffer_.insert(buf); - new_storage_scopes_[buf] = "local"; - } - } - return StmtMutator::VisitStmt_(op); - } - int warp_size_{0}; - std::unordered_set warp_buffer_; arith::Analyzer analyzer_; // variable domain 
std::unordered_map var_dom_; diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 5de22fe8665d..38b3a77b1a0c 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -224,7 +224,6 @@ class StorageFlattener : public StmtExprMutator { ret = Allocate(e.buffer->data, storage_type, shape, make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } - ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(skey.to_string()), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index b216b8b848db..03a01af85f3b 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -398,9 +398,7 @@ class StoragePlanRewriter : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - return this->VisitStmt(op->body); - } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. 
if (attach_map_.count(op)) { @@ -485,8 +483,6 @@ class StoragePlanRewriter : public StmtExprMutator { std::vector nest; for (StorageEntry* e : svec) { if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope, - StringImm(e->scope.to_string()), Evaluate(0))); nest.push_back(e->new_alloc); } } diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index d0f58074ada0..1836b8ecec0d 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -69,7 +69,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(k); ICHECK(layout); - std::string scope = scopes[buffer_var]; + std::string scope = GetPtrStorageScope(GetRef(buffer_var)); if (fragments.count(buffer_var)) { // check if the fragment has met before FragmentInfo info = fragments[buffer_var]; @@ -102,7 +102,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(n); ICHECK(k); - std::string scope = scopes[buffer_var]; + std::string scope = GetPtrStorageScope(GetRef(buffer_var)); // Only wmma.accumulator can use tvm_fill_fragment ICHECK_EQ(scope, "wmma.accumulator"); if (fragments.count(buffer_var)) { @@ -119,16 +119,9 @@ class FragmentGetter : public StmtExprVisitor { // Get memory scope void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - const VarNode* buffer = op->node.as(); - ICHECK(buffer); - scopes[buffer] = op->value.as()->value; - } StmtExprVisitor::VisitStmt_(op); } - // Memory scope for allocations - std::unordered_map scopes; // Fragment metadata for all fragments std::unordered_map fragments; }; diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 0ae02fec9f95..4143577a0b17 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -64,17 +64,6 @@ PrimExpr 
UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { StmtExprMutator::VisitExpr(op->predicate)); } -Stmt UpdatePointerStorageScope::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - auto remapped = Downcast(StmtExprMutator::VisitExpr(GetRef(buf))); - auto new_scope = GetPtrStorageScope(remapped); - return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), - StmtMutator::VisitStmt(op->body)); - } - return StmtMutator::VisitStmt_(op); -} - Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h index 481536a45b27..f310194a4a51 100644 --- a/src/tir/transforms/update_pointer_storage_scope.h +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -40,7 +40,6 @@ class UpdatePointerStorageScope : public StmtExprMutator { virtual PrimExpr VisitExpr_(const VarNode*); virtual PrimExpr VisitExpr_(const LoadNode*); - virtual Stmt VisitStmt_(const AttrStmtNode*); virtual Stmt VisitStmt_(const AllocateNode*); virtual Stmt VisitStmt_(const StoreNode*); diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index e148d4844a3e..2aeb6fee3158 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -31,8 +31,6 @@ def test_for(): A[j] = A[j] + 2 body = ib.get() - assert isinstance(body, tvm.tir.AttrStmt) - body = body.body assert isinstance(body, tvm.tir.Allocate) body = body.body assert isinstance(body, tvm.tir.For) diff --git a/tests/python/unittest/test_tir_transform_coproc_sync.py b/tests/python/unittest/test_tir_transform_coproc_sync.py index 2d45118f39f2..7dacd8e046cc 100644 --- 
a/tests/python/unittest/test_tir_transform_coproc_sync.py +++ b/tests/python/unittest/test_tir_transform_coproc_sync.py @@ -51,7 +51,7 @@ def meminfo_cache(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - body = stmt.body.body.body + body = stmt.body.body blist = tvm.tir.stmt_list(body) assert blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier")) @@ -112,7 +112,7 @@ def __check_list(tvm_array, py_list): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - slist = tvm.tir.stmt_list(stmt[0].body.body) + slist = tvm.tir.stmt_list(stmt[0].body) push_st = slist[2] slist = tvm.tir.stmt_list(slist[-1]) pop_st = slist[0].body[0] diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index ceb32c484c6d..9b37bcaaacbc 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -47,8 +47,8 @@ def test_double_buffer(): mod = opt(mod) stmt = mod["db"].body - assert isinstance(stmt.body.body, tvm.tir.Allocate) - assert stmt.body.body.extents[0].value == 2 + assert isinstance(stmt.body, tvm.tir.Allocate) + assert stmt.body.extents[0].value == 2 f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 3e7a5a0cb300..673267a9b1fa 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -49,13 +49,13 @@ def get_vthread(name): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("vthread"))) - )["main"].body + )["main"] assert 
stmt.body.body.extents[0].value == 2 stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"].body + )["main"] assert len(stmt.body.body.extents) == 3 @@ -94,11 +94,11 @@ def get_vthread(name): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"].body + )["main"] assert stmt.body.body.extents[0].value == 2 - assert stmt.body.body.body.body.body.body.extents[0].value == 2 - assert len(stmt.body.body.body.body.body.body.extents) == 3 + assert stmt.body.body.body.body.extents[0].value == 2 + assert len(stmt.body.body.body.body.extents) == 3 def test_vthread_if_then_else(): @@ -119,7 +119,7 @@ def test_vthread_if_then_else(): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) - )["main"].body + )["main"] assert stmt.body.body.body[0].else_case != None assert stmt.body.body.body[1].else_case == None diff --git a/tests/python/unittest/test_tir_transform_lift_attr_scope.py b/tests/python/unittest/test_tir_transform_lift_attr_scope.py index 12ad16dfe092..65e317dfbcb8 100644 --- a/tests/python/unittest/test_tir_transform_lift_attr_scope.py +++ b/tests/python/unittest/test_tir_transform_lift_attr_scope.py @@ -38,7 +38,7 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] assert body.body.body.node == cp @@ -58,7 +58,7 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] assert body.body.body.body[1].node == cp assert len(body.body.body.body) == 2 diff --git 
a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 6194024748e0..c632f744bb81 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -178,7 +178,7 @@ def test_vectorize(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) - stmt = tvm.lower(s, [A, B], name="main")["main"].body + stmt = tvm.lower(s, [A, B], name="main")["main"] body = stmt.body.body.body.body assert x.var.name not in str(body.condition) assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))) @@ -229,7 +229,7 @@ def test_thread_axis2(): _, x = s[C].split(x, factor=m) s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) - stmt = tvm.lower(s, [A, B], name="main")["main"].body + stmt = tvm.lower(s, [A, B], name="main")["main"] for_body = stmt.body.body.body.body[0] assert "threadIdx" not in str(for_body.extent) diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index f3baff120cf6..84bf0c4d52fd 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -47,8 +47,9 @@ def test_lower_warp_memory_local_scope(): fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] mod = tvm.IRModule.from_expr(fdevice) fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] - assert fdevice.body.body.value.value == "local" - assert fdevice.body.body.body.extents[0].value == 2 + allocate = fdevice.body.body + assert allocate.buffer_var.type_annotation.storage_scope == "local" + assert fdevice.body.body.extents[0].value == 2 @tvm.testing.requires_cuda diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py 
b/tests/python/unittest/test_tir_transform_storage_flatten.py index 2d1fea01aa32..0e9ab862a9c8 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -79,7 +79,7 @@ def test_flatten_storage_align(): )(mod) stmt = mod["main"].body - assert stmt.body.extents[0].value == 17 * 8 + assert stmt.extents[0].value == 17 * 8 def test_flatten_double_buffer(): @@ -114,8 +114,8 @@ def test_flatten_double_buffer(): )(mod) stmt = mod["main"].body - assert isinstance(stmt.body.body, tvm.tir.Allocate) - assert stmt.body.body.extents[0].value == 2 + assert isinstance(stmt.body, tvm.tir.Allocate) + assert stmt.body.extents[0].value == 2 mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db")) f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 70e77ff69fea..9e738b136b17 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -298,9 +298,9 @@ def test_storage_share_gpu(): alloc_stats = {"global": 0, "shared": 0} def verify(n): - if isinstance(n, tvm.tir.AttrStmt): - if n.attr_key == "storage_scope": - alloc_stats[n.value.value] += 1 + if isinstance(n, tvm.tir.Allocate): + scope = n.buffer_var.type_annotation.storage_scope + alloc_stats[scope] += 1 tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert alloc_stats["global"] == 2 @@ -317,7 +317,7 @@ def test_parallel_alloc(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] assert isinstance(body.body.body, tvm.tir.Allocate) @@ -334,7 +334,7 @@ def test_parallel_alloc(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - 
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] assert isinstance(body.body.body.body.body, tvm.tir.Allocate) @@ -356,7 +356,6 @@ def get_mod(kind="serial"): mod = get_mod(kind="parallel") # parallel (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] # j[0] = 0 # while((j[0] < 10)){ @@ -366,11 +365,9 @@ def get_mod(kind="serial"): # j[0] = (j[0] + (j[0] + 1)) # } # } - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + body = tvm.tir.transform.StorageRewrite()(mod)["main"] # parallel (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] - # // attr [A] storage_scope = "global" # allocate A[float32 * n] # j[0] = 0 # while((j[0] < 10)){ @@ -379,11 +376,10 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body.body, tvm.tir.Allocate) # A mod = get_mod(kind="serial") # for (i, 0, n) { - # // attr [j] storage_scope = "global" # allocate j[int32 * 1] # j[0] = 0 # while((j[0] < 10)){ @@ -393,10 +389,8 @@ def get_mod(kind="serial"): # j[0] = (j[0] + (j[0] + 1)) # } # } - body = tvm.tir.transform.StorageRewrite()(mod)["main"].body - # // attr [j] storage_scope = "global" + body = tvm.tir.transform.StorageRewrite()(mod)["main"] # allocate j[int32 * 1] - # // attr [A] storage_scope = "global" # allocate A[float32 * n] # for (i, 0, n) { # j[0] = 0 @@ -406,7 +400,7 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body, tvm.tir.Allocate) # A def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 030c01713927..7fff6a804e4a 100644 
--- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -52,7 +52,7 @@ def test_thread_storage_sync(): mod = tvm.IRModule.from_expr(fdevice) cuda_target = tvm.target.Target("cuda") f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"] - body_list = tvm.tir.stmt_list(f.body.body.body.body) + body_list = tvm.tir.stmt_list(f.body.body.body) assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")) From c95ede5e6398d8c729c0c26d7e2939a441c578e7 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 27 Jul 2021 23:21:46 +0900 Subject: [PATCH 08/21] fix allocate location --- ...merge_dynamic_shared_memory_allocations.cc | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 25bf730a7739..83f83b89c72c 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -58,20 +58,26 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { DynamicSharedMemoryRewriter(const std::unordered_set& dyn_shmem_allocs) : dyn_shmem_allocs_{dyn_shmem_allocs} {} - Stmt Rewrite(Stmt stmt) { - int align = 1; - for (auto& alloc : dyn_shmem_allocs_) { - align = std::max(align, alloc->dtype.bytes()); + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::thread_extent && !allocated) { + // Allocate one dynamic shared memory allocation at the beginning of thread scope + int align = 1; + for (auto& alloc : dyn_shmem_allocs_) { + align = std::max(align, alloc->dtype.bytes()); + } + for (auto& alloc : dyn_shmem_allocs_) { + buffer_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; + merged_alloc_size_ += alloc->extents[0] * align; + LOG(INFO) << "buffer offset for " << alloc->buffer_var->name_hint << " = " + << 
buffer_offsets_[alloc->buffer_var.get()]; + } + + allocated = true; + auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, + const_true(), StmtExprMutator::VisitStmt(op->body)); + return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span); } - for (auto& alloc : dyn_shmem_allocs_) { - buffer_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; - merged_alloc_size_ += alloc->extents[0] * align; - LOG(INFO) << "buffer offset for " << alloc->buffer_var->name_hint << " = " - << buffer_offsets_[alloc->buffer_var.get()]; - } - - return Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_}, const_true(), - StmtExprMutator::VisitStmt(stmt)); + return StmtMutator::VisitStmt_(op); } Stmt VisitStmt_(const AllocateNode* op) final { @@ -111,12 +117,13 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { std::unordered_set dyn_shmem_allocs_; PrimExpr merged_alloc_size_{0}; std::unordered_map buffer_offsets_; + bool allocated{false}; }; Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) { AllocateCollector collector; collector(stmt); - return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_).Rewrite(std::move(stmt)); + return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_)(std::move(stmt)); } namespace transform { From d55ce65768ec530a0aa7ac14ea7c73364ee7af9b Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 27 Jul 2021 23:47:54 +0900 Subject: [PATCH 09/21] Revert "Remove all attr::storage_scope usage" This reverts commit ce62d9e53ea467bd33d1485006aba97a02719712. 
--- include/tvm/tir/stmt.h | 2 + python/tvm/contrib/hexagon.py | 16 +++---- python/tvm/script/scope_handler.py | 2 +- python/tvm/tir/ir_builder.py | 2 + src/printer/tvmscript_printer.cc | 48 +++++++++++-------- src/relay/backend/aot_executor_codegen.cc | 2 + src/target/source/codegen_c.cc | 15 ++++-- src/te/operation/cross_thread_reduction.cc | 3 ++ src/tir/analysis/verify_gpu_code.cc | 21 +++++--- src/tir/ir/stmt.cc | 10 ++++ src/tir/transforms/flatten_buffer.cc | 1 + src/tir/transforms/inject_copy_intrin.cc | 15 ++++-- src/tir/transforms/inject_double_buffer.cc | 17 +++++-- src/tir/transforms/ir_utils.cc | 10 ++++ .../lower_device_storage_access_info.cc | 46 ++++++++++++++---- src/tir/transforms/lower_thread_allreduce.cc | 12 ++++- src/tir/transforms/lower_warp_memory.cc | 17 ++++++- src/tir/transforms/storage_flatten.cc | 1 + src/tir/transforms/storage_rewrite.cc | 6 ++- .../transforms/tensorcore_infer_fragment.cc | 11 ++++- .../update_pointer_storage_scope.cc | 11 +++++ .../transforms/update_pointer_storage_scope.h | 1 + tests/python/unittest/test_tir_ir_builder.py | 2 + .../test_tir_transform_coproc_sync.py | 4 +- ...test_tir_transform_inject_double_buffer.py | 4 +- ...est_tir_transform_inject_virtual_thread.py | 12 ++--- .../test_tir_transform_lift_attr_scope.py | 4 +- .../test_tir_transform_loop_partition.py | 4 +- .../test_tir_transform_lower_warp_memory.py | 5 +- .../test_tir_transform_storage_flatten.py | 6 +-- .../test_tir_transform_storage_rewrite.py | 24 ++++++---- .../test_tir_transform_thread_sync.py | 2 +- 32 files changed, 240 insertions(+), 96 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index c41cac2a3a25..9997a4d95694 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1240,6 +1240,8 @@ constexpr const char* extern_scope = "extern_scope"; * This can hint some code generator to create a new function for compute. */ constexpr const char* compute_scope = "compute_scope"; +/*! 
\brief Mark storage scope of buffers */ +constexpr const char* storage_scope = "storage_scope"; /*! \brief Mark storage alignement requirement of buffers */ constexpr const char* storage_alignment = "storage_alignment"; /*! \brief Mark storage scope of realization */ diff --git a/python/tvm/contrib/hexagon.py b/python/tvm/contrib/hexagon.py index c2197af22d2a..34b37537776f 100644 --- a/python/tvm/contrib/hexagon.py +++ b/python/tvm/contrib/hexagon.py @@ -176,27 +176,23 @@ def buf_align(var): def visit(stmt): """Collect information about VTCM buffers and their alignments.""" if isinstance(stmt, tvm.tir.AttrStmt): - if stmt.attr_key == "storage_alignment": + if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm": + vtcm_buffers.append(stmt.node) + elif stmt.attr_key == "storage_alignment": if not stmt.node in alignments: alignments[stmt.node] = [] alignments[stmt.node].append(stmt.value) - elif isinstance(stmt, tvm.tir.Allocate): - scope = stmt.buffer_var.type_annotation.storage_scope - if scope == "local.vtcm": - vtcm_buffers.append(stmt.node) - def mutate(stmt): """Insert calls to VTCM allocation and deallocation routines.""" if isinstance(stmt, tvm.tir.AttrStmt): - if stmt.attr_key == "storage_alignment": + if stmt.attr_key == "storage_scope" and stmt.value == "local.vtcm": + vtcm_buffers.pop() + elif stmt.attr_key == "storage_alignment": alignments[stmt.node].pop() return stmt if isinstance(stmt, tvm.tir.Allocate): var = stmt.buffer_var - scope = var.type_annotation.storage_scope - if scope == "local.vtcm": - vtcm_buffers.pop() if var in vtcm_buffers: is_null = tvm.tir.call_intrin("bool", tvm.ir.Op.get("tir.isnullptr"), var) throw_error = tvm.tir.call_intrin( diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index 971580343763..d07209485bd4 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -113,7 +113,7 @@ def allocate(extents, dtype, scope, condition=True, span=None): body = 
tvm.tir.Allocate( self.buffer_var, dtype, extents, condition, self.body, span=span ) - return body + return tvm.tir.AttrStmt(self.buffer_var, "storage_scope", scope, body, span=span) super().__init__(allocate, concise_scope=True, def_symbol=True) self.buffer_var = None diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 978c630b17ad..35932540fe68 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -419,6 +419,8 @@ def allocate(self, dtype, shape, name="buf", scope=""): buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] + if scope: + self.scope_attr(buffer_var, "storage_scope", scope) self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) return BufferVar(self, buffer_var, shape, dtype) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index cc8aa48f3cd1..01f79bd0c750 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -37,7 +37,6 @@ #include #include -#include "../tir/transforms/ir_utils.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" @@ -580,6 +579,31 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { Doc doc; + // merge attr with allocate when possible + if (op->node->IsInstance() && op->attr_key == "storage_scope" && + op->body->IsInstance()) { + const auto* alloc = Downcast(op->body).get(); + if (alloc->buffer_var.same_as(op->node)) { + var_not_in_headers.insert(alloc->buffer_var.get()); + if (current_num_ != num_child_ - 1) { + doc << "with tir.allocate(" << Print(alloc->extents) << ", " << PrintDType(alloc->dtype) + << ", " << Print(op->value); + if (!is_one(alloc->condition)) { + doc << ", " << Print(alloc->condition); + } + doc << ") as " << Print(op->node) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << 
PrintBody(alloc->body)); + } else { + doc << Print(op->node) << " = tir.allocate(" << Print(alloc->extents) << ", " + << PrintDType(alloc->dtype) << ", " << Print(op->value); + if (!is_one(alloc->condition)) { + doc << ", " << Print(alloc->condition); + } + doc << ")" << Doc::NewLine() << PrintBody(alloc->body); + } + return doc; + } + } // merge attr with realize when possible if (op->node->IsInstance() && op->attr_key == "realize_scope" && op->body->IsInstance()) { @@ -657,26 +681,8 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { } Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { - var_not_in_headers.insert(op->buffer_var.get()); - Doc doc; - auto storage_scope = GetPtrStorageScope(op->buffer_var); - if (current_num_ != num_child_ - 1) { - doc << "with tir.allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype) << ", " - << Print(storage_scope); - if (!is_one(op->condition)) { - doc << ", " << Print(op->condition); - } - doc << ") as " << Print(op->buffer_var) << ":"; - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); - } else { - doc << Print(op->buffer_var) << " = tir.allocate(" << Print(op->extents) << ", " - << PrintDType(op->dtype) << ", " << Print(storage_scope); - if (!is_one(op->condition)) { - doc << ", " << Print(op->condition); - } - doc << ")" << Doc::NewLine() << PrintBody(op->body); - } - return doc; + LOG(FATAL) << "TVM Script Printer Internal Error: All the Allocate should be folded with Attr"; + return Doc(); } Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index fd6ee27eb6be..4df38b9449ae 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -625,6 +625,8 @@ class AOTExecutorCodegen : public ExprVisitor { // so we don't pay the price of allocation for every inference if (!allocated[sid]) { body = tir::Allocate(sids_table_[sid], 
DataType::Int(8), {size}, tir::const_true(), body); + body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"), + body); } allocated[sid] = true; } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 8397044e8b93..99c9452975d4 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -861,11 +861,12 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - - auto scope = GetPtrStorageScope(op->buffer_var); - alloc_storage_scope_[op->buffer_var.get()] = scope; - PrintStorageScope(scope, stream); - + const VarNode* buffer = op->buffer_var.as(); + auto it = alloc_storage_scope_.find(buffer); + if (it != alloc_storage_scope_.end()) { + std::string scope = alloc_storage_scope_.at(buffer); + PrintStorageScope(scope, stream); + } PrintType(op->dtype, stream); stream << ' ' << vid << '[' << constant_size << "];\n"; @@ -881,6 +882,10 @@ void CodeGenC::VisitStmt_(const AttrStmtNode* op) { BindThreadIndex(iv); } } + } else if (op->attr_key == tir::attr::storage_scope) { + const VarNode* v = op->node.as(); + ICHECK(v); + alloc_storage_scope_[v] = op->value.as()->value; } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); ICHECK(v); diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 2ed5fd4029a2..f844090ca6f5 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -225,9 +225,12 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = AttrStmt(res_handles[idx - 1], 
tir::attr::storage_scope, StringImm("local"), body); if (!normal_red.empty()) { body = Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = + AttrStmt(normal_res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); } } body = Substitute(body, value_map); diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index 10d857bdc953..afd3c7add605 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -30,8 +30,6 @@ #include #include -#include "../transforms/ir_utils.h" - namespace tvm { namespace tir { @@ -60,12 +58,11 @@ class GPUCodeVerifier : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) final { StmtVisitor::VisitStmt_(op); - auto scope = GetPtrStorageScope(op->buffer_var); // visit an allocation of a buffer in shared memory, record its size - if (scope == "local") { + if (visited_local_buffers_.count(op->buffer_var.get()) != 0) { size_t size = static_cast(op->constant_allocation_size()); local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); - } else if (scope == "shared") { + } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) { size_t size = static_cast(op->constant_allocation_size()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } @@ -81,7 +78,15 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::storage_scope) { + std::string op_value = op->value.as()->value; + if (op_value == "local") { + visited_local_buffers_.insert(op->node.as()); + } else if (op_value == "shared") { + visited_shared_buffers_.insert(op->node.as()); + } + StmtVisitor::VisitStmt_(op); + } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { if (nest_level_ == 0) { // enter a new kernel, 
reset statistics Reset_(); @@ -206,6 +211,8 @@ class GPUCodeVerifier : public StmtExprVisitor { private: int nest_level_{0}; + std::unordered_set visited_local_buffers_; + std::unordered_set visited_shared_buffers_; std::unordered_set visited_threads_; size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; @@ -223,6 +230,8 @@ class GPUCodeVerifier : public StmtExprVisitor { std::vector errors_; void Reset_() { + visited_local_buffers_.clear(); + visited_shared_buffers_.clear(); local_memory_per_block_ = 0; shared_memory_per_block_ = 0; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6fdeb30ec100..42ef60bb86d7 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -61,6 +61,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { + if (attr_key == attr::storage_scope) { + const VarNode* buf = node.as(); + ICHECK(buf); + const auto* ptr_type = buf->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + auto attr_scope = value.as()->value; + ICHECK_EQ(attr_scope, ptr_type->storage_scope) + << "Storage scopes attached to AttrStmt and buffer var are different. 
" << attr_scope + << ", " << ptr_type->storage_scope; + } auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index f1f914fa2f5c..88c254a8cb5e 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -130,6 +130,7 @@ class BufferFlattener : public StmtExprMutator { String storage_scope = buffer.scope(); PrimExpr area = BufferArea(buffer); body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body)); + body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), std::move(body)); return body; } diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index 74538bcb6806..40f0e368d93d 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -29,7 +29,6 @@ #include #include "../../arith/pattern_match.h" -#include "ir_utils.h" namespace tvm { namespace tir { @@ -43,7 +42,10 @@ class CopyIntrinInjector : public StmtMutator { flower_copy_fromto_(flower_copy_fromto) {} Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == pragma_key_) { + if (op->attr_key == attr::storage_scope) { + const VarNode* buf = op->node.as(); + storage_scope_[buf] = op->value.as()->value; + } else if (op->attr_key == pragma_key_) { Stmt ret; ICHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body; return ret; @@ -155,12 +157,19 @@ class CopyIntrinInjector : public StmtMutator { } // Get storage scope std::string GetStorageScope(const VarNode* var) const { - return GetPtrStorageScope(GetRef(var)); + auto it = storage_scope_.find(var); + if (it != storage_scope_.end()) { + return it->second; + } else { + return ""; + } } // pragma key std::string pragma_key_; // function to lower copy intrinsics. 
const PackedFunc& flower_copy_fromto_; + // Storage scope + std::unordered_map storage_scope_; // arith analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 0b45bde28dfe..7a16c06d8058 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -95,7 +95,16 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::double_buffer_scope) { + if (op->attr_key == attr::storage_scope) { + const VarNode* buf = op->node.as(); + auto it = dbuffer_info_.find(buf); + if (it != dbuffer_info_.end()) { + it->second.scope = op->value.as()->value; + return this->VisitStmt(op->body); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } else if (op->attr_key == attr::double_buffer_scope) { return MakeProducer(op); } else { return StmtExprMutator::VisitStmt_(op); @@ -103,10 +112,8 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt VisitStmt_(const AllocateNode* op) final { - const VarNode* buf = op->buffer_var.as(); - auto it = dbuffer_info_.find(buf); + auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { - it->second.scope = GetPtrStorageScope(op->buffer_var); it->second.stride = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes(); @@ -118,6 +125,8 @@ class DoubleBufferInjector : public StmtExprMutator { } ICHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; + alloc_nest.emplace_back( + AttrStmt(op->buffer_var, attr::storage_scope, StringImm(it->second.scope), Evaluate(0))); alloc_nest.emplace_back( Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); return op->body; diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 
b7348fe09fe2..f7ece25d3fcd 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -172,6 +172,16 @@ class IRConvertSSA final : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { if (const VarNode* v = op->node.as()) { + if (op->attr_key == attr::storage_scope) { + const AllocateNode* alloc = op->body.as(); + if (alloc && op->node.same_as(alloc->buffer_var)) { + Stmt new_alloc = this->VisitStmt(op->body); + if (new_alloc.same_as(op->body)) return GetRef(op); + alloc = new_alloc.as(); + ICHECK(alloc); + return AttrStmt(alloc->buffer_var, op->attr_key, op->value, new_alloc); + } + } Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 0893f02d7443..eafed837cee3 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -44,13 +44,13 @@ class StorageAccessInfoLower : public StmtExprMutator { // Lower allocate to device allocate when needed. 
Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - ICHECK(storage_info_.find(op->buffer_var.get()) == storage_info_.end()) - << "Double allocation of " << scope.to_string(); - if (scope.tag.length() != 0) { - auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var)); - ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string(); - storage_info_[op->buffer_var.get()] = info; + // For special memory, remove allocate, or use head expr + auto it = storage_info_.find(op->buffer_var.get()); + if (it != storage_info_.end() && it->second.info.defined()) { + const MemoryInfo& info = it->second.info; + ++it->second.alloc_count; + ICHECK_LE(it->second.alloc_count, 1) + << "Double allocation of " << it->second.scope.to_string(); if (info->head_address.defined()) { return LetStmt(op->buffer_var, info->head_address, op->body); @@ -61,6 +61,23 @@ class StorageAccessInfoLower : public StmtExprMutator { return stmt; } } + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::storage_scope) { + const VarNode* buf = op->node.as(); + StorageScope scope = StorageScope::Create(op->value.as()->value); + StorageEntry e; + e.scope = scope; + if (scope.tag.length() != 0 && scope.tag != ".dyn") { + e.info = GetMemoryInfo(op->value.as()->value); + ICHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string(); + } + storage_info_[buf] = e; + return StmtExprMutator::VisitStmt_(op); + + } else { + return StmtExprMutator::VisitStmt_(op); + } + } PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::tvm_access_ptr())) { @@ -82,8 +99,8 @@ class StorageAccessInfoLower : public StmtExprMutator { Var buffer_var = Downcast(op->args[1]); PrimExpr offset = op->args[2]; auto it = storage_info_.find(buffer); - if (it != storage_info_.end() && it->second.defined()) { - return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, 
it->second); + if (it != storage_info_.end() && it->second.info.defined()) { + return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second.info); } ICHECK(op->dtype.is_handle()); // Change to address_of @@ -101,8 +118,17 @@ class StorageAccessInfoLower : public StmtExprMutator { return cast(ptr_type, analyzer_.Simplify( offset / make_const(offset.dtype(), info->unit_bits / dtype_bits))); } + // The storage entry. + struct StorageEntry { + // Whether it is tagged memory. + StorageScope scope; + // The memory info if any. + MemoryInfo info; + // Allocation counter + int alloc_count{0}; + }; // The storage scope of each buffer - std::unordered_map storage_info_; + std::unordered_map storage_info_; // analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 481b1bfd4b19..25a2f4e060dd 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -53,7 +53,8 @@ class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScop // use volatile access to shared buffer. 
body = AttrStmt(remapped, attr::volatile_scope, 1, body); } - return Allocate(remapped, op->dtype, op->extents, op->condition, body); + body = Allocate(remapped, op->dtype, op->extents, op->condition, body); + return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), body); } return StmtExprMutator::VisitStmt_(op); } @@ -70,6 +71,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Stmt ret = StmtExprMutator::VisitStmt_(op); thread_extents_.pop_back(); return ret; + } else if (op->attr_key == attr::storage_scope) { + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + const VarNode* v = op->node.as(); + if (alloc_remap_.count(v)) { + return op->body; + } else { + return ret; + } } else if (op->attr_key == attr::reduce_scope) { const CommReducerNode* combiner = op->node.as(); ICHECK(combiner); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 8cc6d3f2541f..060b02c3d137 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -364,15 +364,28 @@ class WarpMemoryRewriter : private StmtMutator { Stmt VisitStmt_(const AllocateNode* op) { auto ret = StmtMutator::VisitStmt_(op); op = ret.as(); - if (GetPtrStorageScope(op->buffer_var) == "warp") { - new_storage_scopes_[op->buffer_var.get()] = "local"; + if (warp_buffer_.count(op->buffer_var.get())) { WarpAccessRewriter rewriter(warp_size_, &analyzer_); ret = rewriter.Rewrite(op); } return ret; } + Stmt VisitStmt_(const AttrStmtNode* op) { + using runtime::StorageScope; + if (op->attr_key == attr::storage_scope) { + const VarNode* buf = op->node.as(); + StorageScope scope = StorageScope::Create(op->value.as()->value); + if (scope.rank == runtime::StorageRank::kWarp) { + warp_buffer_.insert(buf); + new_storage_scopes_[buf] = "local"; + } + } + return StmtMutator::VisitStmt_(op); + } + int warp_size_{0}; + std::unordered_set warp_buffer_; arith::Analyzer analyzer_; // variable domain 
std::unordered_map var_dom_; diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 38b3a77b1a0c..5de22fe8665d 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -224,6 +224,7 @@ class StorageFlattener : public StmtExprMutator { ret = Allocate(e.buffer->data, storage_type, shape, make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } + ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(skey.to_string()), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 03a01af85f3b..b216b8b848db 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -398,7 +398,9 @@ class StoragePlanRewriter : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || + if (op->attr_key == attr::storage_scope) { + return this->VisitStmt(op->body); + } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. 
if (attach_map_.count(op)) { @@ -483,6 +485,8 @@ class StoragePlanRewriter : public StmtExprMutator { std::vector nest; for (StorageEntry* e : svec) { if (e->new_alloc.defined()) { + nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope, + StringImm(e->scope.to_string()), Evaluate(0))); nest.push_back(e->new_alloc); } } diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 1836b8ecec0d..d0f58074ada0 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -69,7 +69,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(k); ICHECK(layout); - std::string scope = GetPtrStorageScope(GetRef(buffer_var)); + std::string scope = scopes[buffer_var]; if (fragments.count(buffer_var)) { // check if the fragment has met before FragmentInfo info = fragments[buffer_var]; @@ -102,7 +102,7 @@ class FragmentGetter : public StmtExprVisitor { ICHECK(n); ICHECK(k); - std::string scope = GetPtrStorageScope(GetRef(buffer_var)); + std::string scope = scopes[buffer_var]; // Only wmma.accumulator can use tvm_fill_fragment ICHECK_EQ(scope, "wmma.accumulator"); if (fragments.count(buffer_var)) { @@ -119,9 +119,16 @@ class FragmentGetter : public StmtExprVisitor { // Get memory scope void VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::storage_scope) { + const VarNode* buffer = op->node.as(); + ICHECK(buffer); + scopes[buffer] = op->value.as()->value; + } StmtExprVisitor::VisitStmt_(op); } + // Memory scope for allocations + std::unordered_map scopes; // Fragment metadata for all fragments std::unordered_map fragments; }; diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc index 4143577a0b17..0ae02fec9f95 100644 --- a/src/tir/transforms/update_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -64,6 +64,17 @@ PrimExpr 
UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { StmtExprMutator::VisitExpr(op->predicate)); } +Stmt UpdatePointerStorageScope::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == attr::storage_scope) { + const VarNode* buf = op->node.as(); + auto remapped = Downcast(StmtExprMutator::VisitExpr(GetRef(buf))); + auto new_scope = GetPtrStorageScope(remapped); + return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), + StmtMutator::VisitStmt(op->body)); + } + return StmtMutator::VisitStmt_(op); +} + Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h index f310194a4a51..481536a45b27 100644 --- a/src/tir/transforms/update_pointer_storage_scope.h +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -40,6 +40,7 @@ class UpdatePointerStorageScope : public StmtExprMutator { virtual PrimExpr VisitExpr_(const VarNode*); virtual PrimExpr VisitExpr_(const LoadNode*); + virtual Stmt VisitStmt_(const AttrStmtNode*); virtual Stmt VisitStmt_(const AllocateNode*); virtual Stmt VisitStmt_(const StoreNode*); diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 2aeb6fee3158..e148d4844a3e 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -31,6 +31,8 @@ def test_for(): A[j] = A[j] + 2 body = ib.get() + assert isinstance(body, tvm.tir.AttrStmt) + body = body.body assert isinstance(body, tvm.tir.Allocate) body = body.body assert isinstance(body, tvm.tir.For) diff --git a/tests/python/unittest/test_tir_transform_coproc_sync.py b/tests/python/unittest/test_tir_transform_coproc_sync.py index 7dacd8e046cc..2d45118f39f2 100644 --- 
a/tests/python/unittest/test_tir_transform_coproc_sync.py +++ b/tests/python/unittest/test_tir_transform_coproc_sync.py @@ -51,7 +51,7 @@ def meminfo_cache(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - body = stmt.body.body + body = stmt.body.body.body blist = tvm.tir.stmt_list(body) assert blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier")) @@ -112,7 +112,7 @@ def __check_list(tvm_array, py_list): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body - slist = tvm.tir.stmt_list(stmt[0].body) + slist = tvm.tir.stmt_list(stmt[0].body.body) push_st = slist[2] slist = tvm.tir.stmt_list(slist[-1]) pop_st = slist[0].body[0] diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index 9b37bcaaacbc..ceb32c484c6d 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -47,8 +47,8 @@ def test_double_buffer(): mod = opt(mod) stmt = mod["db"].body - assert isinstance(stmt.body, tvm.tir.Allocate) - assert stmt.body.extents[0].value == 2 + assert isinstance(stmt.body.body, tvm.tir.Allocate) + assert stmt.body.body.extents[0].value == 2 f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 673267a9b1fa..3e7a5a0cb300 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -49,13 +49,13 @@ def get_vthread(name): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("vthread"))) - )["main"] + )["main"].body assert 
stmt.body.body.extents[0].value == 2 stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"] + )["main"].body assert len(stmt.body.body.extents) == 3 @@ -94,11 +94,11 @@ def get_vthread(name): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], get_vthread("cthread"))) - )["main"] + )["main"].body assert stmt.body.body.extents[0].value == 2 - assert stmt.body.body.body.body.extents[0].value == 2 - assert len(stmt.body.body.body.body.extents) == 3 + assert stmt.body.body.body.body.body.body.extents[0].value == 2 + assert len(stmt.body.body.body.body.body.body.extents) == 3 def test_vthread_if_then_else(): @@ -119,7 +119,7 @@ def test_vthread_if_then_else(): stmt = tvm.tir.transform.InjectVirtualThread()( tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) - )["main"] + )["main"].body assert stmt.body.body.body[0].else_case != None assert stmt.body.body.body[1].else_case == None diff --git a/tests/python/unittest/test_tir_transform_lift_attr_scope.py b/tests/python/unittest/test_tir_transform_lift_attr_scope.py index 65e317dfbcb8..12ad16dfe092 100644 --- a/tests/python/unittest/test_tir_transform_lift_attr_scope.py +++ b/tests/python/unittest/test_tir_transform_lift_attr_scope.py @@ -38,7 +38,7 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body assert body.body.body.node == cp @@ -58,7 +58,7 @@ def test_coproc_lift(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"] + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body assert body.body.body.body[1].node == cp assert len(body.body.body.body) == 2 diff --git 
a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index c632f744bb81..6194024748e0 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -178,7 +178,7 @@ def test_vectorize(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) - stmt = tvm.lower(s, [A, B], name="main")["main"] + stmt = tvm.lower(s, [A, B], name="main")["main"].body body = stmt.body.body.body.body assert x.var.name not in str(body.condition) assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))) @@ -229,7 +229,7 @@ def test_thread_axis2(): _, x = s[C].split(x, factor=m) s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) - stmt = tvm.lower(s, [A, B], name="main")["main"] + stmt = tvm.lower(s, [A, B], name="main")["main"].body for_body = stmt.body.body.body.body[0] assert "threadIdx" not in str(for_body.extent) diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index 84bf0c4d52fd..f3baff120cf6 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -47,9 +47,8 @@ def test_lower_warp_memory_local_scope(): fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] mod = tvm.IRModule.from_expr(fdevice) fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] - allocate = fdevice.body.body - assert allocate.buffer_var.type_annotation.storage_scope == "local" - assert fdevice.body.body.extents[0].value == 2 + assert fdevice.body.body.value.value == "local" + assert fdevice.body.body.body.extents[0].value == 2 @tvm.testing.requires_cuda diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py 
b/tests/python/unittest/test_tir_transform_storage_flatten.py index 0e9ab862a9c8..2d1fea01aa32 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -79,7 +79,7 @@ def test_flatten_storage_align(): )(mod) stmt = mod["main"].body - assert stmt.extents[0].value == 17 * 8 + assert stmt.body.extents[0].value == 17 * 8 def test_flatten_double_buffer(): @@ -114,8 +114,8 @@ def test_flatten_double_buffer(): )(mod) stmt = mod["main"].body - assert isinstance(stmt.body, tvm.tir.Allocate) - assert stmt.body.extents[0].value == 2 + assert isinstance(stmt.body.body, tvm.tir.Allocate) + assert stmt.body.body.extents[0].value == 2 mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, C], stmt).with_attr("global_symbol", "db")) f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 9e738b136b17..70e77ff69fea 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -298,9 +298,9 @@ def test_storage_share_gpu(): alloc_stats = {"global": 0, "shared": 0} def verify(n): - if isinstance(n, tvm.tir.Allocate): - scope = n.buffer_var.type_annotation.storage_scope - alloc_stats[scope] += 1 + if isinstance(n, tvm.tir.AttrStmt): + if n.attr_key == "storage_scope": + alloc_stats[n.value.value] += 1 tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert alloc_stats["global"] == 2 @@ -317,7 +317,7 @@ def test_parallel_alloc(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - body = tvm.tir.transform.StorageRewrite()(mod)["main"] + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body assert isinstance(body.body.body, tvm.tir.Allocate) @@ -334,7 +334,7 @@ def test_parallel_alloc(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) - 
body = tvm.tir.transform.StorageRewrite()(mod)["main"] + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body assert isinstance(body.body.body.body.body, tvm.tir.Allocate) @@ -356,6 +356,7 @@ def get_mod(kind="serial"): mod = get_mod(kind="parallel") # parallel (i, 0, n) { + # // attr [j] storage_scope = "global" # allocate j[int32 * 1] # j[0] = 0 # while((j[0] < 10)){ @@ -365,9 +366,11 @@ def get_mod(kind="serial"): # j[0] = (j[0] + (j[0] + 1)) # } # } - body = tvm.tir.transform.StorageRewrite()(mod)["main"] + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body # parallel (i, 0, n) { + # // attr [j] storage_scope = "global" # allocate j[int32 * 1] + # // attr [A] storage_scope = "global" # allocate A[float32 * n] # j[0] = 0 # while((j[0] < 10)){ @@ -376,10 +379,11 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body.body.body, tvm.tir.Allocate) # A mod = get_mod(kind="serial") # for (i, 0, n) { + # // attr [j] storage_scope = "global" # allocate j[int32 * 1] # j[0] = 0 # while((j[0] < 10)){ @@ -389,8 +393,10 @@ def get_mod(kind="serial"): # j[0] = (j[0] + (j[0] + 1)) # } # } - body = tvm.tir.transform.StorageRewrite()(mod)["main"] + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + # // attr [j] storage_scope = "global" # allocate j[int32 * 1] + # // attr [A] storage_scope = "global" # allocate A[float32 * n] # for (i, 0, n) { # j[0] = 0 @@ -400,7 +406,7 @@ def get_mod(kind="serial"): # } # } assert isinstance(body.body, tvm.tir.Allocate) # j - assert isinstance(body.body.body, tvm.tir.Allocate) # A + assert isinstance(body.body.body.body, tvm.tir.Allocate) # A def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 7fff6a804e4a..030c01713927 100644 
--- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -52,7 +52,7 @@ def test_thread_storage_sync(): mod = tvm.IRModule.from_expr(fdevice) cuda_target = tvm.target.Target("cuda") f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"] - body_list = tvm.tir.stmt_list(f.body.body.body) + body_list = tvm.tir.stmt_list(f.body.body.body.body) assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")) From c2892b63dcd0c8e8d2397b0537c6cd5e662727d3 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 28 Jul 2021 11:00:33 +0900 Subject: [PATCH 10/21] handle vector alloc --- ...merge_dynamic_shared_memory_allocations.cc | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 83f83b89c72c..fd5f87918a99 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -63,13 +63,11 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { // Allocate one dynamic shared memory allocation at the beginning of thread scope int align = 1; for (auto& alloc : dyn_shmem_allocs_) { - align = std::max(align, alloc->dtype.bytes()); + align = std::max(align, alloc->dtype.bytes() * alloc->dtype.lanes()); } for (auto& alloc : dyn_shmem_allocs_) { - buffer_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; + buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; merged_alloc_size_ += alloc->extents[0] * align; - LOG(INFO) << "buffer offset for " << alloc->buffer_var->name_hint << " = " - << buffer_offsets_[alloc->buffer_var.get()]; } allocated = true; @@ -108,22 +106,25 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { private: PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { - auto it = 
buffer_offsets_.find(buffer_var.get()); - ICHECK(it != buffer_offsets_.end()); - return indexdiv(it->second, dtype.bytes()); + auto it = buffer_byte_offsets_.find(buffer_var.get()); + ICHECK(it != buffer_byte_offsets_.end()); + return indexdiv(it->second, dtype.bytes() * dtype.lanes()); } Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; std::unordered_set dyn_shmem_allocs_; PrimExpr merged_alloc_size_{0}; - std::unordered_map buffer_offsets_; + std::unordered_map buffer_byte_offsets_; bool allocated{false}; }; Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) { AllocateCollector collector; collector(stmt); - return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_)(std::move(stmt)); + if (collector.dyn_shmem_allocs_.size() > 0) { + return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_)(std::move(stmt)); + } + return stmt; } namespace transform { From e8907f1ec210f985bdcecee247c7cee7e6fb85b5 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 28 Jul 2021 11:39:33 +0900 Subject: [PATCH 11/21] add vectorized test --- ...merge_dynamic_shared_memory_allocations.cc | 1 + tests/python/unittest/test_tir_ir_builder.py | 63 ++++++++++++++++++- 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index fd5f87918a99..8ecb89f60968 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -66,6 +66,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { align = std::max(align, alloc->dtype.bytes() * alloc->dtype.lanes()); } for (auto& alloc : dyn_shmem_allocs_) { + ICHECK_EQ(alloc->extents.size(), 1); buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; merged_alloc_size_ += alloc->extents[0] * align; } diff --git a/tests/python/unittest/test_tir_ir_builder.py 
b/tests/python/unittest/test_tir_ir_builder.py index e148d4844a3e..3864bd5ea72d 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -633,6 +633,66 @@ def check_target(target): check_target(target) +@tvm.testing.requires_gpu +def test_dyn_shared_vectorized(): + n = te.size_var("n") + A = te.placeholder((n,), name="A", dtype="float32") + B = te.placeholder((n,), name="B", dtype="float32") + + def test_device_ir(A, B, C): + n = 512 # A.shape[0] + ib = tvm.tir.ir_builder.create() + + values_per_thread = 4 + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", tvm.tir.indexdiv(n, values_per_thread)) + + A_sh = ib.allocate(A.dtype, (n,), scope="shared") + B_sh = ib.allocate(B.dtype, (n,), scope="shared") + + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + with ib.for_range(0, values_per_thread, kind="vectorize") as i: + A_sh[tx * values_per_thread + i] = Aptr[tx * values_per_thread + i] + B_sh[tx * values_per_thread + i] = Bptr[tx * values_per_thread + i] + + with ib.for_range(0, values_per_thread) as i: + Cptr[tx * values_per_thread + i] = ( + A_sh[tx * values_per_thread + i] + B_sh[tx * values_per_thread + i] + ) + + return ib.get() + + C = te.extern( + (n,), + [A, B], + lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]), + name="vadd", + dtype="float32", + ) + s = te.create_schedule(C.op) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fadd = tvm.build(s, [A, B, C], target) + dev = tvm.device(target, 0) + + for n in [512, 1024]: + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev) + fadd(a, b, c) + tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy(), 1e-4, 1e-4) + break + + for target in ["cuda", "nvptx"]: + check_target(target) + + if __name__ == "__main__": # 
test_prefetch() # test_if() @@ -644,4 +704,5 @@ def check_target(target): # test_while_mandel() # test_while_binary_search() # test_dyn_shared() - test_matmul_dyn_shared() + # test_matmul_dyn_shared() + test_dyn_shared_vectorized() From 1f5c79ce3932e9bc7f0858946158bd70f505681b Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 28 Jul 2021 13:20:03 +0900 Subject: [PATCH 12/21] drop multi-lane dtype allocation support, vectorized store working --- ...merge_dynamic_shared_memory_allocations.cc | 5 ++- tests/python/unittest/test_tir_ir_builder.py | 42 ++++++++++--------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 8ecb89f60968..980b9680ade2 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -63,7 +63,8 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { // Allocate one dynamic shared memory allocation at the beginning of thread scope int align = 1; for (auto& alloc : dyn_shmem_allocs_) { - align = std::max(align, alloc->dtype.bytes() * alloc->dtype.lanes()); + ICHECK_EQ(alloc->dtype.lanes(), 1) << "vector dtype allocation not supported."; + align = std::max(align, alloc->dtype.bytes()); } for (auto& alloc : dyn_shmem_allocs_) { ICHECK_EQ(alloc->extents.size(), 1); @@ -109,7 +110,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) { auto it = buffer_byte_offsets_.find(buffer_var.get()); ICHECK(it != buffer_byte_offsets_.end()); - return indexdiv(it->second, dtype.bytes() * dtype.lanes()); + return indexdiv(it->second, dtype.bytes()); } Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")}; diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 
3864bd5ea72d..7cfc2d724d1b 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -634,21 +634,22 @@ def check_target(target): @tvm.testing.requires_gpu -def test_dyn_shared_vectorized(): +def test_dyn_shared_vectorized_store(): + """Test vectorized store into dynamic shared memory""" n = te.size_var("n") - A = te.placeholder((n,), name="A", dtype="float32") + A = te.placeholder((n,), name="A", dtype="float16") B = te.placeholder((n,), name="B", dtype="float32") def test_device_ir(A, B, C): - n = 512 # A.shape[0] + n = A.shape[0] ib = tvm.tir.ir_builder.create() values_per_thread = 4 tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", tvm.tir.indexdiv(n, values_per_thread)) - A_sh = ib.allocate(A.dtype, (n,), scope="shared") - B_sh = ib.allocate(B.dtype, (n,), scope="shared") + A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn") + B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn") Aptr = ib.buffer_ptr(A) Bptr = ib.buffer_ptr(B) @@ -660,7 +661,7 @@ def test_device_ir(A, B, C): with ib.for_range(0, values_per_thread) as i: Cptr[tx * values_per_thread + i] = ( - A_sh[tx * values_per_thread + i] + B_sh[tx * values_per_thread + i] + cast(A_sh[tx * values_per_thread + i], "float32") + B_sh[tx * values_per_thread + i] ) return ib.get() @@ -686,23 +687,24 @@ def check_target(target): b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev) fadd(a, b, c) - tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy(), 1e-4, 1e-4) - break + tvm.testing.assert_allclose( + c.numpy(), a.numpy().astype("float32") + b.numpy(), 1e-4, 1e-4 + ) for target in ["cuda", "nvptx"]: check_target(target) if __name__ == "__main__": - # test_prefetch() - # test_if() - # test_for() - # test_cpu() - # test_gpu() - # test_while_vectorize() - # test_while_collatz() - # test_while_mandel() - # test_while_binary_search() - # test_dyn_shared() - # 
test_matmul_dyn_shared() - test_dyn_shared_vectorized() + test_prefetch() + test_if() + test_for() + test_cpu() + test_gpu() + test_while_vectorize() + test_while_collatz() + test_while_mandel() + test_while_binary_search() + test_dyn_shared() + test_matmul_dyn_shared() + test_dyn_shared_vectorized_store() From 6d6f437532d078a03261734f0d7ceffc0a0676ab Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 28 Jul 2021 13:25:18 +0900 Subject: [PATCH 13/21] doc update --- python/tvm/tir/transform/transform.py | 3 ++- .../transforms/merge_dynamic_shared_memory_allocations.cc | 6 ++---- tests/python/unittest/test_tir_ir_builder.py | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 42efdf2e9ba9..970491a1a306 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -658,7 +658,8 @@ def FlattenBuffer(): def MergeDynamicSharedMemoryAllocations(): - """TODO + """This pass merges multiple TIR-level dynamic shared memory allocations + into one allocation. Returns ------- diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 980b9680ade2..78697e53373d 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -19,10 +19,10 @@ /*! * \file merge_dynamic_shared_memory_allocations.cc + * \brief Each GPU kernel is allowed to have only one dynamic shared memory allocation. + * This pass merges multiple TIR-level dynamic shared memory allocations into one allocation. 
*/ #include -#include -#include #include #include #include @@ -134,9 +134,7 @@ namespace transform { Pass MergeDynamicSharedMemoryAllocations() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - LOG(INFO) << "Before: " << f; n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body)); - LOG(INFO) << "After: " << f; return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {}); diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 7cfc2d724d1b..0bdaf279cbb8 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -648,8 +648,8 @@ def test_device_ir(A, B, C): tx = te.thread_axis("threadIdx.x") ib.scope_attr(tx, "thread_extent", tvm.tir.indexdiv(n, values_per_thread)) - A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn") - B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn") + A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn") # fp16 + B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn") # fp32 Aptr = ib.buffer_ptr(A) Bptr = ib.buffer_ptr(B) From f086e37236560f2ae330dc9073ef6e8784a256dc Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 28 Jul 2021 13:33:19 +0900 Subject: [PATCH 14/21] dtype fix in the test --- tests/python/unittest/test_tir_ir_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 0bdaf279cbb8..72c207f2709e 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -557,8 +557,8 @@ def check_target(target): @tvm.testing.requires_gpu def test_matmul_dyn_shared(): n = 1024 - A = te.placeholder((n, n), name="A", dtype="float32") - B = te.placeholder((n, n), name="B", dtype="float32") + A = te.placeholder((n, n), name="A", dtype="float16") + B = te.placeholder((n, n), name="B", 
dtype="float16") def syncthread(): return tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])) From 391b2c66f652bdfd249daf165427e9beb319b792 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 28 Jul 2021 13:44:47 +0900 Subject: [PATCH 15/21] lint fix, do not run merging when number of alloc is 1 --- .../transforms/merge_dynamic_shared_memory_allocations.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 78697e53373d..1177b3baa9ac 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -55,7 +55,8 @@ class AllocateCollector : public StmtExprVisitor { class DynamicSharedMemoryRewriter : public StmtExprMutator { public: - DynamicSharedMemoryRewriter(const std::unordered_set& dyn_shmem_allocs) + explicit DynamicSharedMemoryRewriter( + const std::unordered_set& dyn_shmem_allocs) : dyn_shmem_allocs_{dyn_shmem_allocs} {} Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -123,7 +124,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) { AllocateCollector collector; collector(stmt); - if (collector.dyn_shmem_allocs_.size() > 0) { + if (collector.dyn_shmem_allocs_.size() > 1) { return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_)(std::move(stmt)); } return stmt; From 6203b0bb3960969aeb4725ab16dbb2956c808ed3 Mon Sep 17 00:00:00 2001 From: masahi Date: Thu, 29 Jul 2021 14:54:18 +0900 Subject: [PATCH 16/21] Update src/tir/transforms/merge_dynamic_shared_memory_allocations.cc Co-authored-by: Siyuan Feng --- src/tir/transforms/merge_dynamic_shared_memory_allocations.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc 
b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 1177b3baa9ac..8a54f591a34b 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -63,7 +63,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { if (op->attr_key == attr::thread_extent && !allocated) { // Allocate one dynamic shared memory allocation at the beginning of thread scope int align = 1; - for (auto& alloc : dyn_shmem_allocs_) { + for (const auto& alloc : dyn_shmem_allocs_) { ICHECK_EQ(alloc->dtype.lanes(), 1) << "vector dtype allocation not supported."; align = std::max(align, alloc->dtype.bytes()); } From 56f8dce78481364bf92e1841b4da13f9e0ef30d2 Mon Sep 17 00:00:00 2001 From: masahi Date: Thu, 29 Jul 2021 14:54:30 +0900 Subject: [PATCH 17/21] Update src/tir/transforms/merge_dynamic_shared_memory_allocations.cc Co-authored-by: Siyuan Feng --- src/tir/transforms/merge_dynamic_shared_memory_allocations.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 8a54f591a34b..e8865b260dc1 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -67,7 +67,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { ICHECK_EQ(alloc->dtype.lanes(), 1) << "vector dtype allocation not supported."; align = std::max(align, alloc->dtype.bytes()); } - for (auto& alloc : dyn_shmem_allocs_) { + for (const auto& alloc : dyn_shmem_allocs_) { ICHECK_EQ(alloc->extents.size(), 1); buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_; merged_alloc_size_ += alloc->extents[0] * align; From ec3d0da1d5253bafdabecce0bcd7f67e8c992087 Mon Sep 17 00:00:00 2001 From: masa Date: Thu, 29 Jul 2021 14:16:14 +0900 Subject: [PATCH 18/21] move test cases into dedicated file 
--- tests/python/unittest/test_tir_ir_builder.py | 143 --------------- ...merge_dynamic_shared_memory_allocations.py | 167 ++++++++++++++++++ 2 files changed, 167 insertions(+), 143 deletions(-) create mode 100644 tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 72c207f2709e..0329134bb3fa 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -554,147 +554,6 @@ def check_target(target): check_target(target) -@tvm.testing.requires_gpu -def test_matmul_dyn_shared(): - n = 1024 - A = te.placeholder((n, n), name="A", dtype="float16") - B = te.placeholder((n, n), name="B", dtype="float16") - - def syncthread(): - return tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])) - - def test_matmul_ir(A, B, C): - ib = tvm.tir.ir_builder.create() - block = 16 - - tx = te.thread_axis("threadIdx.x") - ty = te.thread_axis("threadIdx.y") - bx = te.thread_axis("blockIdx.x") - by = te.thread_axis("blockIdx.y") - ib.scope_attr(tx, "thread_extent", block) - ib.scope_attr(ty, "thread_extent", block) - ib.scope_attr(bx, "thread_extent", n / block) - ib.scope_attr(by, "thread_extent", n / block) - - A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn", name="A_sh") # fp16 - B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn", name="B_sh") # fp16 - # Create a dynamic shared memory for the accumulation. - # This is for testing merging dynamic shared memory alloctions with different data type. - # In practice, there is no need to allocate a shared memory for C. 
- C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh") # fp32 - - A_ptr = ib.buffer_ptr(A) - B_ptr = ib.buffer_ptr(B) - C_ptr = ib.buffer_ptr(C) - - C_sh[ty, tx] = 0.0 - - with ib.for_range(0, n // block, name="i") as i: - A_sh[ty, tx] = A_ptr[by * block + ty, i * block + tx] - B_sh[ty, tx] = B_ptr[i * block + ty, bx * block + tx] - ib.emit(syncthread()) - - with ib.for_range(0, block, name="k") as k: - C_sh[ty, tx] += cast(A_sh[ty, k] * B_sh[k, tx], "float32") - - ib.emit(syncthread()) - - C_ptr[by * block + ty, bx * block + tx] = C_sh[ty, tx] - - return ib.get() - - C = te.extern( - A.shape, - [A, B], - lambda ins, outs: test_matmul_ir(ins[0], ins[1], outs[0]), - name="reduce", - dtype="float32", - ) - s = te.create_schedule(C.op) - - def check_target(target): - if not tvm.testing.device_enabled(target): - return - - fmatmul = tvm.build(s, [A, B, C], target) - dev = tvm.device(target, 0) - - size = (n, n) - a_np = np.random.uniform(size=size).astype(A.dtype) - b_np = np.random.uniform(size=size).astype(B.dtype) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros(size, dtype=C.dtype), dev) - fmatmul(a, b, c) - np_ref = np.dot(a_np.astype("float32"), b_np.astype("float32")) - tvm.testing.assert_allclose(c.numpy(), np_ref, 1e-4, 1e-4) - - for target in ["cuda", "nvptx"]: - check_target(target) - - -@tvm.testing.requires_gpu -def test_dyn_shared_vectorized_store(): - """Test vectorized store into dynamic shared memory""" - n = te.size_var("n") - A = te.placeholder((n,), name="A", dtype="float16") - B = te.placeholder((n,), name="B", dtype="float32") - - def test_device_ir(A, B, C): - n = A.shape[0] - ib = tvm.tir.ir_builder.create() - - values_per_thread = 4 - tx = te.thread_axis("threadIdx.x") - ib.scope_attr(tx, "thread_extent", tvm.tir.indexdiv(n, values_per_thread)) - - A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn") # fp16 - B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn") # fp32 - - Aptr = 
ib.buffer_ptr(A) - Bptr = ib.buffer_ptr(B) - Cptr = ib.buffer_ptr(C) - - with ib.for_range(0, values_per_thread, kind="vectorize") as i: - A_sh[tx * values_per_thread + i] = Aptr[tx * values_per_thread + i] - B_sh[tx * values_per_thread + i] = Bptr[tx * values_per_thread + i] - - with ib.for_range(0, values_per_thread) as i: - Cptr[tx * values_per_thread + i] = ( - cast(A_sh[tx * values_per_thread + i], "float32") + B_sh[tx * values_per_thread + i] - ) - - return ib.get() - - C = te.extern( - (n,), - [A, B], - lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]), - name="vadd", - dtype="float32", - ) - s = te.create_schedule(C.op) - - def check_target(target): - if not tvm.testing.device_enabled(target): - return - - fadd = tvm.build(s, [A, B, C], target) - dev = tvm.device(target, 0) - - for n in [512, 1024]: - a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev) - fadd(a, b, c) - tvm.testing.assert_allclose( - c.numpy(), a.numpy().astype("float32") + b.numpy(), 1e-4, 1e-4 - ) - - for target in ["cuda", "nvptx"]: - check_target(target) - - if __name__ == "__main__": test_prefetch() test_if() @@ -706,5 +565,3 @@ def check_target(target): test_while_mandel() test_while_binary_search() test_dyn_shared() - test_matmul_dyn_shared() - test_dyn_shared_vectorized_store() diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py new file mode 100644 index 000000000000..56a402a9189b --- /dev/null +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +import numpy as np +import tvm.testing +from tvm.topi.math import cast + + +@tvm.testing.requires_gpu +def test_matmul_dyn_shared(): + n = 1024 + A = te.placeholder((n, n), name="A", dtype="float16") + B = te.placeholder((n, n), name="B", dtype="float16") + + def syncthread(): + return tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"])) + + def test_matmul_ir(A, B, C): + ib = tvm.tir.ir_builder.create() + block = 16 + + tx = te.thread_axis("threadIdx.x") + ty = te.thread_axis("threadIdx.y") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", block) + ib.scope_attr(ty, "thread_extent", block) + ib.scope_attr(bx, "thread_extent", n / block) + ib.scope_attr(by, "thread_extent", n / block) + + A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn", name="A_sh") # fp16 + B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn", name="B_sh") # fp16 + # Create a dynamic shared memory for the accumulation. + # This is for testing merging dynamic shared memory alloctions with different data type. + # In practice, there is no need to allocate a shared memory for C. 
+ C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh") # fp32 + + A_ptr = ib.buffer_ptr(A) + B_ptr = ib.buffer_ptr(B) + C_ptr = ib.buffer_ptr(C) + + C_sh[ty, tx] = 0.0 + + with ib.for_range(0, n // block, name="i") as i: + A_sh[ty, tx] = A_ptr[by * block + ty, i * block + tx] + B_sh[ty, tx] = B_ptr[i * block + ty, bx * block + tx] + ib.emit(syncthread()) + + with ib.for_range(0, block, name="k") as k: + C_sh[ty, tx] += cast(A_sh[ty, k] * B_sh[k, tx], "float32") + + ib.emit(syncthread()) + + C_ptr[by * block + ty, bx * block + tx] = C_sh[ty, tx] + + return ib.get() + + C = te.extern( + A.shape, + [A, B], + lambda ins, outs: test_matmul_ir(ins[0], ins[1], outs[0]), + name="matmul", + dtype="float32", + ) + s = te.create_schedule(C.op) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fmatmul = tvm.build(s, [A, B, C], target) + dev = tvm.device(target, 0) + + size = (n, n) + a_np = np.random.uniform(size=size).astype(A.dtype) + b_np = np.random.uniform(size=size).astype(B.dtype) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(size, dtype=C.dtype), dev) + fmatmul(a, b, c) + np_ref = np.dot(a_np.astype("float32"), b_np.astype("float32")) + tvm.testing.assert_allclose(c.numpy(), np_ref, 1e-4, 1e-4) + + for target in ["cuda", "nvptx"]: + check_target(target) + + +@tvm.testing.requires_gpu +def test_dyn_shared_vectorized_store(): + """Test vectorized store into dynamic shared memory""" + n = te.size_var("n") + A = te.placeholder((n,), name="A", dtype="float16") + B = te.placeholder((n,), name="B", dtype="float32") + + def test_device_ir(A, B, C): + n = A.shape[0] + ib = tvm.tir.ir_builder.create() + + values_per_thread = 4 + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", tvm.tir.indexdiv(n, values_per_thread)) + + A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn") # fp16 + B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn") # fp32 + + Aptr = 
ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + + with ib.for_range(0, values_per_thread, kind="vectorize") as i: + A_sh[tx * values_per_thread + i] = Aptr[tx * values_per_thread + i] + B_sh[tx * values_per_thread + i] = Bptr[tx * values_per_thread + i] + + with ib.for_range(0, values_per_thread) as i: + Cptr[tx * values_per_thread + i] = ( + cast(A_sh[tx * values_per_thread + i], "float32") + B_sh[tx * values_per_thread + i] + ) + + return ib.get() + + C = te.extern( + (n,), + [A, B], + lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]), + name="vadd", + dtype="float32", + ) + s = te.create_schedule(C.op) + + def check_target(target): + if not tvm.testing.device_enabled(target): + return + + fadd = tvm.build(s, [A, B, C], target) + dev = tvm.device(target, 0) + + for n in [512, 1024]: + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev) + c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev) + fadd(a, b, c) + tvm.testing.assert_allclose( + c.numpy(), a.numpy().astype("float32") + b.numpy(), 1e-4, 1e-4 + ) + + for target in ["cuda", "nvptx"]: + check_target(target) + + +if __name__ == "__main__": + test_matmul_dyn_shared() + test_dyn_shared_vectorized_store() From 30bbbbb1c044b3fa3823381b4e217cbf9a0a0d4a Mon Sep 17 00:00:00 2001 From: masa Date: Thu, 29 Jul 2021 14:18:33 +0900 Subject: [PATCH 19/21] use integer division --- ...t_tir_transform_merge_dynamic_shared_memory_allocations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 56a402a9189b..a585010ee0eb 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -40,8 +40,8 @@ def 
test_matmul_ir(A, B, C): by = te.thread_axis("blockIdx.y") ib.scope_attr(tx, "thread_extent", block) ib.scope_attr(ty, "thread_extent", block) - ib.scope_attr(bx, "thread_extent", n / block) - ib.scope_attr(by, "thread_extent", n / block) + ib.scope_attr(bx, "thread_extent", n // block) + ib.scope_attr(by, "thread_extent", n // block) A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn", name="A_sh") # fp16 B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn", name="B_sh") # fp16 From b46e1f20f7144a29ba6946c99b4ac98b12845857 Mon Sep 17 00:00:00 2001 From: masa Date: Thu, 29 Jul 2021 15:11:45 +0900 Subject: [PATCH 20/21] verify new transform pass output --- ...merge_dynamic_shared_memory_allocations.py | 45 ++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index a585010ee0eb..8f008ec0f94a 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -21,9 +21,47 @@ from tvm.topi.math import cast +def run_passes(sch, args): + bounds = tvm.te.schedule.InferBound(sch) + assert isinstance(bounds, tvm.container.Map) + stmt = tvm.te.schedule.ScheduleOps(sch, bounds) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None) + mod = tvm.IRModule.from_expr(func) + return tvm.transform.Sequential( + [ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.Simplify(), + tvm.tir.transform.VectorizeLoop(), + tvm.tir.transform.StorageRewrite(), + tvm.tir.transform.MergeDynamicSharedMemoryAllocations(), + ] + )(mod) + + +def verify_single_allocation(stmt, alloc_size=None): + num_alloc = [0] + alloc_extents = [] + + def verify(n): + if ( + isinstance(n, tvm.tir.Allocate) + and n.buffer_var.type_annotation.storage_scope 
== "shared.dyn" + ): + num_alloc[0] += 1 + alloc_extents.append(n.extents[0]) + + tvm.tir.stmt_functor.post_order_visit(stmt, verify) + assert num_alloc[0] == 1 + + if alloc_size: + assert alloc_extents[0] == alloc_size + + @tvm.testing.requires_gpu def test_matmul_dyn_shared(): n = 1024 + block = 16 A = te.placeholder((n, n), name="A", dtype="float16") B = te.placeholder((n, n), name="B", dtype="float16") @@ -32,7 +70,6 @@ def syncthread(): def test_matmul_ir(A, B, C): ib = tvm.tir.ir_builder.create() - block = 16 tx = te.thread_axis("threadIdx.x") ty = te.thread_axis("threadIdx.y") @@ -78,6 +115,9 @@ def test_matmul_ir(A, B, C): dtype="float32", ) s = te.create_schedule(C.op) + mod = run_passes(s, [A, B, C]) + expected_alloc_size = block * block * 3 * 4 + verify_single_allocation(mod["main"].body, expected_alloc_size) def check_target(target): if not tvm.testing.device_enabled(target): @@ -142,6 +182,9 @@ def test_device_ir(A, B, C): ) s = te.create_schedule(C.op) + mod = run_passes(s, [A, B, C]) + verify_single_allocation(mod["main"].body) + def check_target(target): if not tvm.testing.device_enabled(target): return From 7a37406d10ed5b55bbd460d4bd9541c04911d533 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Jul 2021 09:58:49 +0900 Subject: [PATCH 21/21] add test on combined buffer reuse and merge --- ...merge_dynamic_shared_memory_allocations.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 8f008ec0f94a..9c511f1de6b9 100644 --- a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -205,6 +205,55 @@ def check_target(target): check_target(target) +@tvm.testing.requires_gpu +def test_dyn_shared_reuse_and_merge(): + n = 64 + A = 
te.placeholder((n,), name="A", dtype="float32") + B = te.placeholder((n,), name="B", dtype="float32") + C = te.placeholder((te.size_var("n_dyn"),), name="C", dtype="float32") + + def test_device_ir(A, B, C, D): + ib = tvm.tir.ir_builder.create() + + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", n) + + A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn", name="A_sh") + B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn", name="B_sh") + C_sh = ib.allocate(C.dtype, (C.shape[0],), scope="shared.dyn", name="C_sh") + + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) + Cptr = ib.buffer_ptr(C) + Dptr = ib.buffer_ptr(D) + + A_sh[tx] = Aptr[tx] + Dptr[tx] = A_sh[tx] + + B_sh[tx] = Bptr[tx] + Dptr[tx] += B_sh[tx] + + C_sh[tx] = Cptr[tx] # C cannot reuse other buffers since it size is dynamic + Dptr[tx] += C_sh[tx] + + return ib.get() + + D = te.extern( + (n,), + [A, B, C], + lambda ins, outs: test_device_ir(ins[0], ins[1], ins[2], outs[0]), + name="vadd", + dtype="float32", + ) + s = te.create_schedule(D.op) + + mod = run_passes(s, [A, B, C, D]) + # merged allocation + # allocate(buf_dyn_shmem: Pointer(shared.dyn uint8), uint8, [((n_dyn*4) + 256)]); + verify_single_allocation(mod["main"].body) + + if __name__ == "__main__": test_matmul_dyn_shared() test_dyn_shared_vectorized_store() + test_dyn_shared_reuse_and_merge()