From 90e90387c2906f80e0b092cd525bfdc220be50f0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 22 Jul 2021 12:17:45 +0900 Subject: [PATCH 1/3] Fix storage_access not visiting else branch --- src/tir/transforms/storage_access.cc | 1 + .../test_tir_transform_thread_sync.py | 64 +++++++++++++++---- 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 9dae0006facd..0567c8613fcd 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -172,6 +172,7 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { scope_.pop_back(); if (op->else_case.defined()) { scope_.push_back(std::vector()); + this->VisitStmt(op->else_case); auto v = Summarize(std::move(scope_.back()), nullptr); scope_.pop_back(); s.access.insert(s.access.end(), v.begin(), v.end()); diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 7fff6a804e4a..ffdf4b5916c4 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -19,6 +19,21 @@ import tvm.testing +def run_passes(inputs, stmt): + func = tvm.te.schedule.SchedulePostProcToPrimFunc(inputs, stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + + cuda_target = tvm.target.Target("cuda") + + mod = tvm.tir.transform.Apply( + lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) + )(mod) + + mod = tvm.tir.transform.SplitHostDevice()(mod) + return tvm.tir.transform.ThreadSync("shared")(mod) + + @tvm.testing.requires_cuda def test_thread_storage_sync(): m = te.size_var("m") @@ -38,23 +53,46 @@ def test_thread_storage_sync(): assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) - mod = tvm.IRModule.from_expr(func) - mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) + mod = run_passes([A, A2], stmt) + f = mod["test_kernel0"] + body_list = tvm.tir.stmt_list(f.body.body.body) + assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")) - cuda_target = tvm.target.Target("cuda") - mod = tvm.tir.transform.Apply( - lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) - )(mod._move()) +@tvm.testing.requires_cuda +def test_sync_else_branch(): + def ir(A, B): + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(A) + Bptr = ib.buffer_ptr(B) - fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] - mod = tvm.IRModule.from_expr(fdevice) - cuda_target = tvm.target.Target("cuda") - f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"] - body_list = tvm.tir.stmt_list(f.body.body.body) - assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", 1) + + local = ib.allocate(A.dtype, (8,), name="buf_local", scope="local") + shared = ib.allocate(A.dtype, (8,), name="buf_shared", scope="shared") + + with ib.for_range(0, 8) as i: + with ib.if_scope(Aptr[i] < 0): + local[i] = Aptr[i] + with ib.else_scope(): + shared[i] = Aptr[i] + + with ib.for_range(0, 8) as i: + with ib.if_scope(Aptr[i] < 0): + Bptr[i] = local[i] + with ib.else_scope(): + Bptr[i] = shared[i] + + return ib.get() + + A = tvm.tir.decl_buffer((8,), "float32") + B = tvm.tir.decl_buffer((8,), "float32") + stmt = ir(A, B) + mod = run_passes([A, B], stmt) + assert "@tir.tvm_storage_sync" in str(mod) if __name__ == "__main__": test_thread_storage_sync() + test_sync_else_branch() From 3bb433e452a7369c888e3dd00f3faf613bb3b96c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 22 Jul 2021 12:57:52 +0900 Subject: [PATCH 2/3] fix conflict with #8516 in the test --- tests/python/unittest/test_tir_transform_thread_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index ffdf4b5916c4..4e42187b366c 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -55,7 +55,7 @@ def test_thread_storage_sync(): mod = run_passes([A, A2], stmt) f = mod["test_kernel0"] - body_list = tvm.tir.stmt_list(f.body.body.body) + body_list = tvm.tir.stmt_list(f.body.body.body.body) assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")) From 751d7c8cf359436eeee14593922aae4d058e0074 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Jul 2021 05:25:17 +0900 Subject: [PATCH 3/3] update thread sync test following #8516 update --- tests/python/unittest/test_tir_transform_thread_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 4e42187b366c..ffdf4b5916c4 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -55,7 +55,7 @@ def test_thread_storage_sync(): mod = run_passes([A, A2], stmt) f = mod["test_kernel0"] - body_list = tvm.tir.stmt_list(f.body.body.body.body) + body_list = tvm.tir.stmt_list(f.body.body.body) assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))