diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc
index c21afe400c56..d92986e51a9c 100644
--- a/src/tir/transforms/thread_storage_sync.cc
+++ b/src/tir/transforms/thread_storage_sync.cc
@@ -50,11 +50,27 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
   }
   // Plan the sync
   std::vector Summarize(std::vector seq, const ForNode* loop) final {
+    // Redirect every "shared.dyn" buffer access in seq to one common buffer var
+    // so that all of these accesses are planned together.
+    Var shared_dyn_buf;  // NOTE(review): assumes Var() is undefined by default - confirm
+    for (StmtEntry& entry : seq) {  // seq is a by-value parameter, so its copy is edited
+      for (AccessEntry& access : entry.access) {
+        if (access.scope.rank == StorageRank::kShared && access.scope.tag == ".dyn" &&
+            access.buffer.defined()) {
+          if (!shared_dyn_buf.defined()) {
+            shared_dyn_buf = access.buffer;  // adopt this access's var as the common one
+          } else {
+            access.buffer = shared_dyn_buf;  // rewrite this access to the common var
+          }
+        }
+      }
+    }
+
     // Unsynced reads and writes
     std::vector reads;
     std::vector writes;
     // if it is a loop, rotate two times to consider effect of loop.
-    // simulation based approach to find dependenceies
+    // simulation based approach to find dependencies
     for (size_t i = 0; i < seq.size(); ++i) {
       const StmtEntry& s = seq[i];
       // check if sync before statement is needed.
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 57ea223cf984..571927dffe6e 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -119,7 +119,49 @@ def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32"))
     assert "T.tvm_storage_sync" in str(mod)
 
 
+def test_sync_shared_dyn():  # ThreadSync("shared.dyn") with two separate shared.dyn buffers
+    @T.prim_func(private=True)  # input program: contains no storage sync
+    def func(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")):
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        B = T.allocate([24], "float32", "shared.dyn")  # first shared.dyn allocation
+        C = T.allocate([1], "float32", "local")
+        D = T.allocate([16], "float32", "shared.dyn")  # second shared.dyn allocation
+        threadIdx_x = T.launch_thread("threadIdx.x", 16)
+        B_1 = T.Buffer((24,), data=B, scope="shared.dyn")
+        A_1 = T.Buffer((16,), data=A.data)
+        B_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x]  # write to B
+        C_1 = T.Buffer((1,), data=C, scope="local")
+        C_1[0] = B_1[threadIdx_x // 4 * 6 + threadIdx_x % 4]  # read B at the index just written
+        D_1 = T.Buffer((16,), data=D, scope="shared.dyn")
+        D_1[threadIdx_x] = C_1[0]  # write to D, the other shared.dyn buffer
+        E_1 = T.Buffer((16,), data=E.data)
+        E_1[threadIdx_x] = D_1[threadIdx_x]  # read of D
+
+    @T.prim_func(private=True)  # expected output: same statements as func plus one storage sync
+    def expected(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")):
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        B_1 = T.allocate([24], "float32", "shared.dyn")
+        C_1 = T.allocate([1], "float32", "local")
+        D_1 = T.allocate([16], "float32", "shared.dyn")
+        threadIdx_x = T.launch_thread("threadIdx.x", 16)
+        B_1_1 = T.Buffer((24,), data=B_1, scope="shared.dyn")
+        A_1 = T.Buffer((16,), data=A.data)
+        B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x]
+        C_1_1 = T.Buffer((1,), data=C_1, scope="local")
+        C_1_1[0] = B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4]
+        T.tvm_storage_sync("shared.dyn")  # inserted after the read of B, before the write to D
+        D_1_1 = T.Buffer((16,), data=D_1, scope="shared.dyn")
+        D_1_1[threadIdx_x] = C_1_1[0]
+        E_1 = T.Buffer((16,), data=E.data)
+        E_1[threadIdx_x] = D_1_1[threadIdx_x]
+
+    mod = tvm.IRModule({"main": func})
+    mod = tvm.tir.transform.ThreadSync("shared.dyn")(mod)  # pass under test
+    tvm.ir.assert_structural_equal(mod["main"], expected)  # result must match expected
+
+
 if __name__ == "__main__":
     test_thread_storage_sync()
     test_sync_else_branch()
     test_sync_read_thread_id_independent_location()
+    test_sync_shared_dyn()