From c9998fcad1f0967da065ca60b4c8f7746c37a0ee Mon Sep 17 00:00:00 2001 From: valmat07 Date: Tue, 2 May 2023 13:08:12 +0300 Subject: [PATCH 1/7] fixed the call of the minimum function in the schedule for cuda --- python/tvm/topi/cuda/scatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 39ef5a5a42ca..7f5fb8aa8770 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -227,8 +227,8 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): fused_shape *= i max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) - tdim = min(max_threads, fused_updates_dimension) + tdim = tvm.tir.min(max_threads, fused_updates_dimension) with ib.new_scope(): bdim = ceil_div(fused_shape, tdim) bx = te.thread_axis("blockIdx.x") From 35761837b2db0400f5daf3ceec4a6a36b5e5baf8 Mon Sep 17 00:00:00 2001 From: valmat07 Date: Tue, 2 May 2023 18:07:07 +0300 Subject: [PATCH 2/7] add test for scatter_nd --- tests/python/relay/test_any.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 443637276e24..4b587b06e9aa 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -2148,6 +2148,29 @@ def verify_scatter_nd(data_np, indices_np, updates_np, ref_res): verify_scatter_nd(data, indices, updates, out) +@tvm.testing.uses_gpu +def test_scatter_nd_any_updates(): + def verify_scatter_nd_any_updates(data_np, indices_np, updates_np, ref_res): + indices_shape = (2, relay.Any()) + updates_shape = (relay.Any(), relay.Any()) + data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype)) + indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype))) + updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype))) + + out = relay.op.scatter_nd(data, indices, updates, "add") + + mod = tvm.IRModule() + mod["main"] = relay.Function([data, indices, updates], out) + + check_result([data_np, indices_np, updates_np], mod, [ref_res]) + + data = np.zeros((3, 3)).astype("int64") + indices = np.array([[1, 1, 2, 1], [0, 1, 2, 1]]) + updates = np.array([[2, 3], [1, 1]]) + out = np.array([[0, 0, 0], [0, 0, 0], [2, 3, 1]]) + verify_scatter_nd_any_updates(data, indices, updates, out) + + @tvm.testing.uses_gpu def test_gather(): def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, axis): From 8fb1b43d005a8562800afd21fff15828e807aed9 Mon Sep 17 00:00:00 2001 From: valmat07 Date: Wed, 3 May 2023 14:48:57 +0300 Subject: [PATCH 3/7] update test only for cuda target --- tests/python/relay/test_any.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 4b587b06e9aa..578b180695d9 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -2162,7 +2162,7 @@ def verify_scatter_nd_any_updates(data_np, indices_np, updates_np, ref_res): mod = tvm.IRModule() mod["main"] = relay.Function([data, indices, updates], out) - check_result([data_np, indices_np, updates_np], mod, [ref_res]) + check_result([data_np, indices_np, updates_np], mod, [ref_res], targets=[('cuda', tvm.cuda(0))]) data = np.zeros((3, 3)).astype("int64") indices = np.array([[1, 1, 2, 1], [0, 1, 2, 1]]) From e63d59789934112f338844fef8675076612225cc Mon Sep 17 00:00:00 2001 From: valmat07 Date: Wed, 3 May 2023 18:00:45 +0300 Subject: [PATCH 4/7] fix lint --- tests/python/relay/test_any.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 578b180695d9..277b0f8f2ba8 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -2162,7 +2162,9 @@ def verify_scatter_nd_any_updates(data_np, indices_np, updates_np, ref_res): mod = tvm.IRModule() mod["main"] = relay.Function([data, indices, updates], out) - check_result([data_np, indices_np, updates_np], mod, [ref_res], targets=[('cuda', tvm.cuda(0))]) + check_result( + [data_np, indices_np, updates_np], mod, [ref_res], targets=[('cuda', tvm.cuda(0))] + ) data = np.zeros((3, 3)).astype("int64") indices = np.array([[1, 1, 2, 1], [0, 1, 2, 1]]) From cc2988f458d45294314e003430683d8fc8606d7c Mon Sep 17 00:00:00 2001 From: valmat07 Date: Thu, 4 May 2023 13:50:46 +0300 Subject: [PATCH 5/7] update test --- tests/python/relay/test_any.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 277b0f8f2ba8..38a2c1833d03 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -59,6 +59,7 @@ def check_result( continue if kind == "debug" and (only_vm or dev.device_type != tvm.cpu().device_type): continue + print(tgt) result = relay.create_executor(kind, mod=mod, device=dev, target=tgt).evaluate()(*args) if isinstance(result, tvm.runtime.container.ADT): result = [r.numpy() for r in result] @@ -2152,7 +2153,7 @@ def verify_scatter_nd(data_np, indices_np, updates_np, ref_res): def test_scatter_nd_any_updates(): def verify_scatter_nd_any_updates(data_np, indices_np, updates_np, ref_res): indices_shape = (2, relay.Any()) - updates_shape = (relay.Any(), relay.Any()) + updates_shape = (2, relay.Any()) data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype)) indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype))) updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype))) @@ -2163,13 +2164,13 @@ def verify_scatter_nd_any_updates(data_np, indices_np, updates_np, ref_res): mod["main"] = relay.Function([data, indices, updates], out) check_result( - [data_np, indices_np, updates_np], mod, [ref_res], targets=[('cuda', tvm.cuda(0))] + [data_np, indices_np, updates_np], mod, [ref_res], only_vm=True ) data = np.zeros((3, 3)).astype("int64") - indices = np.array([[1, 1, 2, 1], [0, 1, 2, 1]]) - updates = np.array([[2, 3], [1, 1]]) - out = np.array([[0, 0, 0], [0, 0, 0], [2, 3, 1]]) + indices = np.array([[1, 1], [0, 1]]) + updates = np.array([[2, 2], [1, 1]]) + out = np.array([[0, 0, 0], [0, 0, 0], [2, 2, 1]]) verify_scatter_nd_any_updates(data, indices, updates, out) From ab219de2146d607b7a3c51f06aeacdcda9303307 Mon Sep 17 00:00:00 2001 From: valmat07 Date: Wed, 10 May 2023 19:59:49 +0300 Subject: [PATCH 6/7] fix lint --- tests/python/relay/test_any.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 38a2c1833d03..3899a9068bbc 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -2163,9 +2163,7 @@ def verify_scatter_nd_any_updates(data_np, indices_np, updates_np, ref_res): mod = tvm.IRModule() mod["main"] = relay.Function([data, indices, updates], out) - check_result( - [data_np, indices_np, updates_np], mod, [ref_res], only_vm=True - ) + check_result([data_np, indices_np, updates_np], mod, [ref_res], only_vm=True) data = np.zeros((3, 3)).astype("int64") indices = np.array([[1, 1], [0, 1]]) From 00ee9fb5e7c110326f2e7eabca629fcd091dd278 Mon Sep 17 00:00:00 2001 From: valmat07 Date: Thu, 11 May 2023 14:38:59 +0300 Subject: [PATCH 7/7] apply comments --- tests/python/relay/test_any.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 3899a9068bbc..3cf4e5310669 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -59,7 +59,6 @@ def check_result( continue if kind == "debug" and (only_vm or dev.device_type != tvm.cpu().device_type): continue - print(tgt) result = relay.create_executor(kind, mod=mod, device=dev, target=tgt).evaluate()(*args) if isinstance(result, tvm.runtime.container.ADT): result = [r.numpy() for r in result]