From 240df3e308162e87895603b1ba4b183a3e7fbd8c Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Wed, 16 Dec 2020 01:30:49 +0000 Subject: [PATCH 1/4] Fix GPU dynamic op schedules --- python/tvm/topi/cuda/conv2d_transpose_nchw.py | 7 +++++- python/tvm/topi/cuda/injective.py | 13 +++++++++- python/tvm/topi/cuda/sort.py | 3 +++ src/runtime/vm/vm.cc | 17 +++++++++++-- tests/python/relay/test_any.py | 24 +++++++++++++++++++ 5 files changed, 60 insertions(+), 4 deletions(-) diff --git a/python/tvm/topi/cuda/conv2d_transpose_nchw.py b/python/tvm/topi/cuda/conv2d_transpose_nchw.py index 609d1acc78bd..3b704170a2e9 100644 --- a/python/tvm/topi/cuda/conv2d_transpose_nchw.py +++ b/python/tvm/topi/cuda/conv2d_transpose_nchw.py @@ -179,7 +179,10 @@ def _callback(op): ##### space definition begin ##### n, f, y, x = s[conv].op.axis rc = s[conv].op.reduce_axis[0] - cfg.define_split("tile_n", cfg.axis(n), num_outputs=4) + # TODO(@kevinthesun): Support tuning/optimization for dynamic shape. + bs = pad_data.shape[0] + n_tuning_axis = n if isinstance(bs, tvm.tir.IntImm) else 1 + cfg.define_split("tile_n", cfg.axis(n_tuning_axis), num_outputs=4) cfg.define_split("tile_f", cfg.axis(f), num_outputs=4) cfg.define_split("tile_y", cfg.axis(y), num_outputs=4) cfg.define_split("tile_x", cfg.axis(x), num_outputs=4) @@ -194,6 +197,8 @@ def _callback(op): if cfg.is_fallback: N, F, Y, X = get_const_tuple(conv.shape) + if not isinstance(N, int): + N = 1 _fallback_schedule(N, F, Y, X) ##### space definition end ##### diff --git a/python/tvm/topi/cuda/injective.py b/python/tvm/topi/cuda/injective.py index 60fb12e4975e..7f0790aebf4d 100644 --- a/python/tvm/topi/cuda/injective.py +++ b/python/tvm/topi/cuda/injective.py @@ -44,8 +44,16 @@ def schedule_injective_from_existing(sch, out): # bandwidth. 
vector_width = 4 if out.dtype == "float16" else 1 + is_dynamic_output = False + for dim in out.shape: + if not isinstance(dim, tvm.tir.IntImm): + is_dynamic_output = True + break + + out_len = utils.prod(out.shape) + try: - const_size = utils.get_const_int(utils.prod(out.shape)) + const_size = utils.get_const_int(out_len) need_block_split = const_size > max_block * num_thread * vector_width except ValueError: need_block_split = False @@ -61,6 +69,9 @@ def schedule_injective_from_existing(sch, out): sch[out].bind(bx, te.thread_axis("blockIdx.x")) sch[out].bind(tx, te.thread_axis("threadIdx.x")) else: + # Use fewer threads for dynamic shape ops to avoid runtime error. + if is_dynamic_output: + num_thread //= 2 bx, tx = sch[out].split(fused, factor=num_thread) sch[out].bind(tx, te.thread_axis("threadIdx.x")) sch[out].bind(bx, te.thread_axis("blockIdx.x")) diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 329f0fb897e5..e4e7c53e9ba5 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -565,6 +565,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int tag="topk_gpu", ) + if isinstance(k, tvm.tir.IntImm): + k = k.value + if not isinstance(k, int) or k > 0: beg = [0] * ndim end = data.shape[:-1] + [k if isinstance(k, int) else tvm.te.size_var("dim")] diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 473b5d759272..3f890baf52c0 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -245,6 +245,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In std::vector codes(arity); runtime::TVMArgsSetter setter(values.data(), codes.data()); int idx = 0; + bool is_empty_output = false; for (Index i = 0; i < arg_count; i++) { if (const auto* dt_cell = args[i].as()) { for (size_t fi = 0; fi < dt_cell->size; ++fi) { @@ -254,12 +255,24 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In } } else { auto nd_array = 
Downcast(args[i]); + // We can safely skip CallPacked if there is only one + // output and it is empty. + if (i == arg_count - 1 && output_size == 1) { + for (const auto& dim : nd_array.Shape()) { + if (!dim) { + is_empty_output = true; + break; + } + } + } setter(idx++, nd_array); } } - TVMRetValue rv; - func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); + if (!is_empty_output) { + TVMRetValue rv; + func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); + } } void VirtualMachine::LoadExecutable(const Executable* exec) { diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index da029e1d77ed..38dfa0106f19 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -199,6 +199,15 @@ def test_any_concat(): ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0) check_result([x_np, y_np], mod, ref) + num_inputs = 25 + x = [relay.var("x", shape=(relay.Any(),), dtype="float32") for _ in range(num_inputs)] + z = relay.op.concatenate(x, axis=0) + mod = tvm.IRModule() + mod["main"] = relay.Function(x, z) + x_np = [np.random.uniform(size=(1,)).astype("float32") for _ in range(num_inputs)] + ref = np.concatenate(x_np, axis=0) + check_result(x_np, mod, ref) + def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False): x = relay.var("x", shape=x_shape, dtype="float32") @@ -1430,6 +1439,21 @@ def test_non_max_suppression(): disable_targets=["nvptx"], ) + np_data = np.zeros((1, 0, 6)).astype("float32") + np_valid_count = np.array([0]).astype("int32") + np_indices = np.zeros((1, 0)).astype("int32") + np_max_output_size = -1 + np_indices_result = np.zeros((1, 0)) + np_valid_box_count = np.array([[0]]).astype("int32") + + check_result( + [np_data, np_valid_count, np_indices, np_max_output_size], + mod, + [np_indices_result, np_valid_box_count], + only_vm=False, + disable_targets=["nvptx"], + ) + if __name__ == "__main__": pytest.main([__file__]) From 
5cf404327b3894808a70fb366a9b3a32c83e8474 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Wed, 16 Dec 2020 01:31:26 +0000 Subject: [PATCH 2/4] Fix dynamic shape nms --- python/tvm/topi/cuda/nms.py | 67 ++++++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index d0915d9aa55f..a33a742b92ba 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -22,7 +22,6 @@ from tvm.tir import if_then_else from .sort import argsort, argsort_thrust -from .. import tag def cuda_atomic_add_rule(op): @@ -95,23 +94,23 @@ def rearrange_indices_out_ir(data, output, valid_box_count): with ib.new_scope(): i = te.thread_axis("blockIdx.x") ib.scope_attr(i, "thread_extent", batch_size) - valid_idx = ib.allocate("int32", (1), name="valid_idx", scope="local") - valid_idx[0] = 0 + valid_idx = ib.allocate("int32", (batch_size,), name="valid_idx", scope="local") + valid_idx[i] = 0 with ib.for_range(0, num_anchors, name="j") as j: with ib.if_scope(data[i, j] >= 0): with ib.if_scope(data[i, j] > num_anchors): - output[i, valid_idx[0]] = 0 - valid_idx[0] = valid_idx[0] + 1 + output[i, valid_idx[i]] = 0 + valid_idx[i] = valid_idx[i] + 1 with ib.else_scope(): - output[i, valid_idx[0]] = data[i, j] - valid_idx[0] = valid_idx[0] + 1 + output[i, valid_idx[i]] = data[i, j] + valid_idx[i] = valid_idx[i] + 1 with ib.else_scope(): with ib.if_scope(data[i, j] < -num_anchors): - output[i, valid_idx[0]] = 0 - valid_idx[0] = valid_idx[0] + 1 - with ib.if_scope(j >= valid_idx[0]): + output[i, valid_idx[i]] = 0 + valid_idx[i] = valid_idx[i] + 1 + with ib.if_scope(j >= valid_idx[i]): output[i, j] = -1 - valid_box_count[i, 0] = valid_idx[0] + valid_box_count[i, 0] = valid_idx[i] return ib.get() @@ -654,6 +653,35 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): return ib.get() +def _fetch_score_ir(data, score, axis): + """ + Fetch score from data. 
+ This routine is required for dynamic shape nms. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + elem_length = data.shape[2] + + ib = tvm.tir.ir_builder.create() + + data = ib.buffer_ptr(data) + score = ib.buffer_ptr(score) + with ib.if_scope(num_anchors > 0): + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = batch_size * num_anchors // max_threads + 1 + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + + tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size * num_anchors): + score[tid] = data[tid * elem_length + axis] + + return ib.get() + + def non_max_suppression( data, valid_count, @@ -754,7 +782,22 @@ def non_max_suppression( ) score_axis = score_index score_shape = (batch_size, num_anchors) - score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE) + data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8) + score_tensor = te.extern( + [score_shape], + [data], + lambda ins, outs: _fetch_score_ir( + ins[0], + outs[0], + score_axis, + ), + dtype=[data.dtype], + in_buffers=[data_buf], + out_buffers=[score_buf], + name="fetch_score", + tag="fetch_score", + ) target = tvm.target.Target.current() if ( target From 183b88ab49943426d0a01941e11a77bfdc8ea206 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Wed, 16 Dec 2020 19:04:56 +0000 Subject: [PATCH 3/4] Fix --- python/tvm/topi/cuda/nms.py | 20 ++++++++++---------- src/runtime/contrib/thrust/thrust.cu | 9 +++++++++ tests/python/relay/test_any.py | 3 ++- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index a33a742b92ba..273397071219 100644 --- a/python/tvm/topi/cuda/nms.py 
+++ b/python/tvm/topi/cuda/nms.py @@ -94,23 +94,23 @@ def rearrange_indices_out_ir(data, output, valid_box_count): with ib.new_scope(): i = te.thread_axis("blockIdx.x") ib.scope_attr(i, "thread_extent", batch_size) - valid_idx = ib.allocate("int32", (batch_size,), name="valid_idx", scope="local") - valid_idx[i] = 0 + valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local") + valid_idx[0] = 0 with ib.for_range(0, num_anchors, name="j") as j: with ib.if_scope(data[i, j] >= 0): with ib.if_scope(data[i, j] > num_anchors): - output[i, valid_idx[i]] = 0 - valid_idx[i] = valid_idx[i] + 1 + output[i, valid_idx[0]] = 0 + valid_idx[0] = valid_idx[0] + 1 with ib.else_scope(): - output[i, valid_idx[i]] = data[i, j] - valid_idx[i] = valid_idx[i] + 1 + output[i, valid_idx[0]] = data[i, j] + valid_idx[0] = valid_idx[0] + 1 with ib.else_scope(): with ib.if_scope(data[i, j] < -num_anchors): - output[i, valid_idx[i]] = 0 - valid_idx[i] = valid_idx[i] + 1 - with ib.if_scope(j >= valid_idx[i]): + output[i, valid_idx[0]] = 0 + valid_idx[0] = valid_idx[0] + 1 + with ib.if_scope(j >= valid_idx[0]): output[i, j] = -1 - valid_box_count[i, 0] = valid_idx[i] + valid_box_count[i, 0] = valid_idx[0] return ib.get() diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 8ccefc5ee7d2..dddbb043fddc 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -205,6 +205,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") if (value_dtype == "int32") { thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); } else if (value_dtype == "float32") { thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); @@ -215,6 +218,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") if (value_dtype == "int32") { 
thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); } else if (value_dtype == "float32") { thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); @@ -225,6 +231,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key") if (value_dtype == "int32") { thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); + } else if (value_dtype == "int64") { + thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, + for_scatter); } else if (value_dtype == "float32") { thrust_stable_sort_by_key(keys_in, values_in, keys_out, values_out, for_scatter); diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 38dfa0106f19..4e9dfb622af2 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -54,6 +54,7 @@ def check_result( for kind in ["debug", "vm"]: targets = targets or tvm.testing.enabled_targets() for tgt, ctx in targets: + print(tgt) if disable_targets and tgt in disable_targets: continue if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type): @@ -582,7 +583,7 @@ def verify_any_conv2d_transpose_nchw( data_np = np.random.uniform(size=static_data_shape).astype(dtype) kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) check_result( - [data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=[("llvm", tvm.cpu())] + [data_np, kernel_np], mod, ref_out_shape, assert_shape=True ) From b51ac7b3d8c6894ec677d9fd3a5713148d622bc4 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Wed, 16 Dec 2020 19:19:42 +0000 Subject: [PATCH 4/4] Fix test format --- tests/python/relay/test_any.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 4e9dfb622af2..dfc03c0cf6b1 100644 --- 
a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -582,9 +582,7 @@ def verify_any_conv2d_transpose_nchw( mod["main"] = relay.Function([data, kernel], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) - check_result( - [data_np, kernel_np], mod, ref_out_shape, assert_shape=True - ) + check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True) # TODO(@kevinthesun): Support dynamic input height and width.