From 22706a643c7ab0a802d1e06f0a68130ed500b3ca Mon Sep 17 00:00:00 2001 From: Yining Shi Date: Fri, 8 Dec 2023 06:45:26 +0000 Subject: [PATCH 1/4] [BugFix][TIR] Fix dynamic smem merge leaf alloc --- ...merge_dynamic_shared_memory_allocations.cc | 11 ++++++++- ...merge_dynamic_shared_memory_allocations.py | 24 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 99055cebf2dc..61b086d085f0 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -447,9 +447,13 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { // - leaf stmt(offset = 0) // - end of scope(offset < 0) // In both cases, we need to handle the kill event correctly + auto is_leaf_alloc = [&] (const VarNode* var) { + return seq[i].scope_pair_offset == 0 && + std::find(it->second.gen.begin(), it->second.gen.end(), var) != it->second.gen.end(); + }; if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { for (const VarNode* var : it->second.kill) { - this->Free(var); + if (!is_leaf_alloc(var)) this->Free(var); } } // scope_pair_offset >= 0 means it is either @@ -464,6 +468,11 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { alloc_map_[var] = dst_entry; } } + if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { + for (const VarNode* var : it->second.kill) { + if (is_leaf_alloc(var)) this->Free(var); + } + } } } /*! 
diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 37372059a296..24720ac0da54 100644 --- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -453,6 +453,30 @@ def func( return func +class TestLeafAllocFree(tvm.testing.CompareBeforeAfter): + transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations() + + def before(self): + @T.prim_func + def func(): + threadIdx_x = T.launch_thread("threadIdx.x", 128) + A_sh_data = T.allocate([128], "float32", "shared.dyn") + B_sh_data = T.allocate([128], "float32", "shared.dyn") + A_sh = T.decl_buffer([128], "float32", data=A_sh_data, scope="shared.dyn") + B_sh = T.decl_buffer([128], "float32", data=B_sh_data, scope="shared.dyn") + B_sh[threadIdx_x] = A_sh[threadIdx_x] + return func + + def expected(self): + @T.prim_func + def func(): + threadIdx_x = T.launch_thread("threadIdx.x", 128) + buf_dyn_shmem = T.allocate([1024], "uint8", "shared.dyn") + A_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn") + B_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn") + B_sh[threadIdx_x + 128] = A_sh[threadIdx_x] + return func + if __name__ == "__main__": tvm.testing.main() From a80abf8a48497e7b0b1beb1e1cd1900a724a2e94 Mon Sep 17 00:00:00 2001 From: Yining Shi Date: Fri, 8 Dec 2023 09:04:35 +0000 Subject: [PATCH 2/4] lint --- ...t_tir_transform_merge_dynamic_shared_memory_allocations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 24720ac0da54..b18894ad4f52 100644 --- 
a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -376,7 +376,6 @@ def func( C_local[0] = T.float32(0) for i in range(64): - A_sh[threadIdx_y * 16 + threadIdx_x] = A_flat[ blockIdx_y * 16384 + threadIdx_y * 1024 + i * 16 + threadIdx_x ] @@ -453,6 +452,7 @@ def func( return func + class TestLeafAllocFree(tvm.testing.CompareBeforeAfter): transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations() @@ -465,6 +465,7 @@ def func(): A_sh = T.decl_buffer([128], "float32", data=A_sh_data, scope="shared.dyn") B_sh = T.decl_buffer([128], "float32", data=B_sh_data, scope="shared.dyn") B_sh[threadIdx_x] = A_sh[threadIdx_x] + return func def expected(self): @@ -475,6 +476,7 @@ def func(): A_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn") B_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn") B_sh[threadIdx_x + 128] = A_sh[threadIdx_x] + return func From ae06f22b95df9a583abbc47f3947b78f85ec2001 Mon Sep 17 00:00:00 2001 From: Yining Shi Date: Fri, 8 Dec 2023 09:09:51 +0000 Subject: [PATCH 3/4] lint --- src/tir/transforms/merge_dynamic_shared_memory_allocations.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 61b086d085f0..9ae6126bff01 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -447,9 +447,9 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { // - leaf stmt(offset = 0) // - end of scope(offset < 0) // In both cases, we need to handle the kill event correctly - auto is_leaf_alloc = [&] (const VarNode* var) { + auto is_leaf_alloc = [&](const VarNode* var) { return seq[i].scope_pair_offset == 0 && - std::find(it->second.gen.begin(), 
it->second.gen.end(), var) != it->second.gen.end(); + std::find(it->second.gen.begin(), it->second.gen.end(), var) != it->second.gen.end(); }; if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { for (const VarNode* var : it->second.kill) { From 3170c6edf3d0568ebd64ed8b649f175bba63b46a Mon Sep 17 00:00:00 2001 From: Yining Shi Date: Sun, 10 Dec 2023 07:40:58 +0000 Subject: [PATCH 4/4] fix allocs not correctly updated. --- ...merge_dynamic_shared_memory_allocations.cc | 1 + ...merge_dynamic_shared_memory_allocations.py | 39 +++++++++++++++++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 9ae6126bff01..6c7b0f649cfe 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -519,6 +519,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { StorageEntry* e = it->second; e->const_nbits = std::max(const_nbits, e->const_nbits); const_free_map_.erase(it); + e->allocs.push_back({op->buffer_var.get()}); return e; } // Then start looking at smaller buffers. 
diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py index b18894ad4f52..5dc5ec863fc1 100644 --- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -453,7 +453,9 @@ def func( return func -class TestLeafAllocFree(tvm.testing.CompareBeforeAfter): +class TestSimpleAllocNoReuse(tvm.testing.CompareBeforeAfter): + """Test alloc and free within the same scope.""" + transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations() def before(self): @@ -462,8 +464,8 @@ def func(): threadIdx_x = T.launch_thread("threadIdx.x", 128) A_sh_data = T.allocate([128], "float32", "shared.dyn") B_sh_data = T.allocate([128], "float32", "shared.dyn") - A_sh = T.decl_buffer([128], "float32", data=A_sh_data, scope="shared.dyn") - B_sh = T.decl_buffer([128], "float32", data=B_sh_data, scope="shared.dyn") + A_sh = T.decl_buffer([128], data=A_sh_data, scope="shared.dyn") + B_sh = T.decl_buffer([128], data=B_sh_data, scope="shared.dyn") B_sh[threadIdx_x] = A_sh[threadIdx_x] return func @@ -480,5 +482,36 @@ def func(): return func +class TestSimpleAllocReuse(tvm.testing.CompareBeforeAfter): + """Test alloc and free within the same scope with a reuse chance.""" + + transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations() + + def before(self): + @T.prim_func + def func(): + threadIdx_x = T.launch_thread("threadIdx.x", 128) + A_sh_data = T.allocate([128], "float32", "shared.dyn") + B_sh_data = T.allocate([128], "float32", "shared.dyn") + A_sh = T.decl_buffer([128], data=A_sh_data, scope="shared.dyn") + B_sh = T.decl_buffer([128], data=B_sh_data, scope="shared.dyn") + A_sh[threadIdx_x] = 0 + B_sh[threadIdx_x] = 0 + + return func + + def expected(self): + @T.prim_func + def func(): + threadIdx_x = 
T.launch_thread("threadIdx.x", 128) + buf_dyn_shmem = T.allocate([512], "uint8", "shared.dyn") + A_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn") + B_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn") + A_sh[threadIdx_x] = 0 + B_sh[threadIdx_x] = 0 + + return func + + if __name__ == "__main__": tvm.testing.main()