diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index 39ef5a5a42ca..7f5fb8aa8770 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -227,8 +227,8 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
             fused_shape *= i
 
 
     max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
-    tdim = min(max_threads, fused_updates_dimension)
+    tdim = tvm.tir.min(max_threads, fused_updates_dimension)
     with ib.new_scope():
         bdim = ceil_div(fused_shape, tdim)
         bx = te.thread_axis("blockIdx.x")
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 443637276e24..3cf4e5310669 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -2148,6 +2148,29 @@ def verify_scatter_nd(data_np, indices_np, updates_np, ref_res):
     verify_scatter_nd(data, indices, updates, out)
 
 
+@tvm.testing.uses_gpu
+def test_scatter_nd_any_updates():
+    def verify_scatter_nd_any_updates(data_np, indices_np, updates_np, ref_res):
+        indices_shape = (2, relay.Any())
+        updates_shape = (2, relay.Any())
+        data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
+        indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype)))
+        updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype)))
+
+        out = relay.op.scatter_nd(data, indices, updates, "add")
+
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function([data, indices, updates], out)
+
+        check_result([data_np, indices_np, updates_np], mod, [ref_res], only_vm=True)
+
+    data = np.zeros((3, 3)).astype("int64")
+    indices = np.array([[1, 1], [0, 1]])
+    updates = np.array([[2, 2], [1, 1]])
+    out = np.array([[0, 0, 0], [0, 0, 0], [2, 2, 1]])
+    verify_scatter_nd_any_updates(data, indices, updates, out)
+
+
 @tvm.testing.uses_gpu
 def test_gather():
     def verify_gather(data_shape, indices_shape, data_shape_np, indices_shape_np, axis):