# [TOPI] Fix GPU Dynamic Op Schedule #7117
Changes to the CUDA NMS implementation:

```diff
@@ -22,7 +22,6 @@
 from tvm.tir import if_then_else
 from .sort import argsort, argsort_thrust
 from .. import tag


 def cuda_atomic_add_rule(op):
```
```diff
@@ -95,7 +94,7 @@ def rearrange_indices_out_ir(data, output, valid_box_count):
     with ib.new_scope():
         i = te.thread_axis("blockIdx.x")
         ib.scope_attr(i, "thread_extent", batch_size)
-        valid_idx = ib.allocate("int32", (1), name="valid_idx", scope="local")
+        valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
         valid_idx[0] = 0
         with ib.for_range(0, num_anchors, name="j") as j:
             with ib.if_scope(data[i, j] >= 0):
```
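A side note on the one-character fix above (my explanation, not stated in the PR): in Python, `(1)` is merely a parenthesized integer, while `(1,)` is a one-element tuple, so the shape argument needs the trailing comma:

```python
# Python parenthesization pitfall behind the fix above:
shape_not_a_tuple = (1)   # parentheses only group; this is the int 1
shape_tuple = (1,)        # the trailing comma makes a one-element tuple
assert shape_not_a_tuple == 1
assert isinstance(shape_tuple, tuple) and len(shape_tuple) == 1
```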
```diff
@@ -654,6 +653,35 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
     return ib.get()


+def _fetch_score_ir(data, score, axis):
+    """
+    Fetch score from data.
+    This routine is required for dynamic shape nms.
+    """
+    batch_size = data.shape[0]
+    num_anchors = data.shape[1]
+    elem_length = data.shape[2]
+
+    ib = tvm.tir.ir_builder.create()
+
+    data = ib.buffer_ptr(data)
+    score = ib.buffer_ptr(score)
+    with ib.if_scope(num_anchors > 0):
+        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+        nthread_tx = max_threads
+        nthread_bx = batch_size * num_anchors // max_threads + 1
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+
+        tid = bx * max_threads + tx
+        with ib.if_scope(tid < batch_size * num_anchors):
+            score[tid] = data[tid * elem_length + axis]
+
+    return ib.get()
+
+
 def non_max_suppression(
     data,
     valid_count,
```
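For readers unfamiliar with the IR-builder style, the kernel above amounts to `score[i, j] = data[i, j, axis]` with the `(batch, anchor)` pair flattened into one thread id. A minimal NumPy reference of the intended semantics (illustrative only, not part of the patch; the function name is hypothetical):

```python
import numpy as np

def fetch_score_ref(data, axis):
    """NumPy reference for _fetch_score_ir: score[i, j] = data[i, j, axis].

    Mirrors the IR kernel, which reads the raw buffer at
    data[tid * elem_length + axis] for each flat thread id.
    """
    batch_size, num_anchors, elem_length = data.shape
    flat = data.reshape(-1)  # raw buffer view, as ib.buffer_ptr sees it
    score = np.empty(batch_size * num_anchors, dtype=data.dtype)
    for tid in range(batch_size * num_anchors):
        score[tid] = flat[tid * elem_length + axis]
    return score.reshape(batch_size, num_anchors)
```

For any contiguous array `d`, `fetch_score_ref(d, axis)` matches `d[:, :, axis]`.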
```diff
@@ -754,7 +782,22 @@ def non_max_suppression(
     )
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
+    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
```
> **Contributor:** This looks fine, but I'm a little surprised it's necessary. Do you have a test case that breaks the current code, or is this mostly for performance?

> **Author:** When the NMS workload is large, as in RCNN models, the general CUDA injective schedule can still cause runtime errors even with the improvements in this PR. Any dynamic injective op can hit runtime issues with the current uniform CUDA injective schedule, so the problem is not specific to NMS. Later we may need to revisit this part for GPU dynamic ops and find a better, more general solution (together with more tests).
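For context, here is a minimal sketch of the core of a generic CUDA injective schedule (my simplification, not code from this PR; the real `topi/cuda/injective.py` adds vectorization and other details). All axes are fused and split by the target's maximum thread count, so with dynamic shapes the launch dimensions become runtime expressions:

```python
from tvm import te

n = te.var("n")  # dynamic extent, only known at runtime
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")

s = te.create_schedule(B.op)
# Fuse every axis into one iteration domain, then split it across
# blockIdx.x / threadIdx.x.
fused = s[B].fuse(*s[B].op.axis)
bx, tx = s[B].split(fused, factor=1024)  # factor = target's max thread count
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
# With a dynamic n, the blockIdx extent is a runtime expression; a large or
# degenerate n can push the launch configuration out of valid bounds.
```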
```diff
+    score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
+    score_tensor = te.extern(
+        [score_shape],
+        [data],
+        lambda ins, outs: _fetch_score_ir(
+            ins[0],
+            outs[0],
+            score_axis,
+        ),
+        dtype=[data.dtype],
+        in_buffers=[data_buf],
+        out_buffers=[score_buf],
+        name="fetch_score",
+        tag="fetch_score",
+    )
     target = tvm.target.Target.current()
     if (
         target
```
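Design note (my reading of the change, not stated explicitly in the PR): the old `te.compute` with `tag.ELEMWISE` left score extraction to the generic injective schedule, whereas the new `te.extern` op carries its own thread bindings from `_fetch_score_ir` and guards against `num_anchors == 0`, sidestepping the schedule problem the author describes above.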
Changes to the dynamic shape (`Any`) tests:
```diff
@@ -54,6 +54,7 @@ def check_result(
     for kind in ["debug", "vm"]:
         targets = targets or tvm.testing.enabled_targets()
         for tgt, ctx in targets:
+            print(tgt)
             if disable_targets and tgt in disable_targets:
                 continue
             if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type):
```
```diff
@@ -199,6 +200,15 @@ def test_any_concat():
     ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0)
     check_result([x_np, y_np], mod, ref)

+    num_inputs = 25
+    x = [relay.var("x", shape=(relay.Any(),), dtype="float32") for _ in range(num_inputs)]
+    z = relay.op.concatenate(x, axis=0)
```
> **Contributor:** this tests the injective schedule 👍
```diff
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function(x, z)
+    x_np = [np.random.uniform(size=(1,)).astype("float32") for _ in range(num_inputs)]
+    ref = np.concatenate(x_np, axis=0)
+    check_result(x_np, mod, ref)


 def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
     x = relay.var("x", shape=x_shape, dtype="float32")
```
```diff
@@ -572,9 +582,7 @@ def verify_any_conv2d_transpose_nchw(
     mod["main"] = relay.Function([data, kernel], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
-    check_result(
-        [data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=[("llvm", tvm.cpu())]
-    )
+    check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True)


 # TODO(@kevinthesun): Support dynamic input height and width.
```
```diff
@@ -1430,6 +1438,21 @@ def test_non_max_suppression():
         disable_targets=["nvptx"],
     )

+    np_data = np.zeros((1, 0, 6)).astype("float32")
+    np_valid_count = np.array([0]).astype("int32")
+    np_indices = np.zeros((1, 0)).astype("int32")
+    np_max_output_size = -1
+    np_indices_result = np.zeros((1, 0))
+    np_valid_box_count = np.array([[0]]).astype("int32")
+
+    check_result(
+        [np_data, np_valid_count, np_indices, np_max_output_size],
+        mod,
+        [np_indices_result, np_valid_box_count],
+        only_vm=False,
+        disable_targets=["nvptx"],
```
> **Contributor:** This tests the empty output VM change 👍

> **Author:** There is an issue causing a segfault in dynamic NMS on nvptx, and in general we need thrust for any dynamic shape sorting. For now, nvptx is not ready for these operations.

> **Contributor:** Makes sense. I'm trying to fix the default sort kernel in #7099, if you want to take a look.
```diff
+    )


 if __name__ == "__main__":
     pytest.main([__file__])
```
> **Contributor:** Can we add a test that hits this change?

> **Author:** Yeah, we do have a test for this. Now I enabled all targets.