7 changes: 6 additions & 1 deletion python/tvm/topi/cuda/conv2d_transpose_nchw.py
@@ -179,7 +179,10 @@ def _callback(op):
##### space definition begin #####
n, f, y, x = s[conv].op.axis
rc = s[conv].op.reduce_axis[0]
cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
# TODO(@kevinthesun): Support tuning/optimization for dynamic shape.
bs = pad_data.shape[0]
n_tuning_axis = n if isinstance(bs, tvm.tir.IntImm) else 1
cfg.define_split("tile_n", cfg.axis(n_tuning_axis), num_outputs=4)
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
@@ -194,6 +197,8 @@ def _callback(op):

if cfg.is_fallback:
N, F, Y, X = get_const_tuple(conv.shape)
if not isinstance(N, int):
N = 1
Contributor: Can we add a test that hits this change?

Contributor Author: Yeah, we do have a test for this. I've now enabled it for all targets.

_fallback_schedule(N, F, Y, X)

##### space definition end #####
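
Editorial note: a minimal sketch (not part of this PR, placeholder shapes made up) of why N can fail the isinstance(N, int) check above — with a symbolic batch dimension, get_const_tuple returns the TIR variable itself rather than a Python int, so the fallback schedule needs a concrete value such as 1.

import tvm
from tvm import te
from tvm.topi.utils import get_const_tuple

n = te.var("n")  # symbolic (dynamic) batch dimension
conv = te.placeholder((n, 8, 32, 32), name="conv")

N, F, Y, X = get_const_tuple(conv.shape)
# N is a tir.Var rather than an int, so the schedule above falls back to
# N = 1 before calling _fallback_schedule.
print(type(N), isinstance(N, int))
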
13 changes: 12 additions & 1 deletion python/tvm/topi/cuda/injective.py
@@ -44,8 +44,16 @@ def schedule_injective_from_existing(sch, out):
# bandwidth.
vector_width = 4 if out.dtype == "float16" else 1

is_dynamic_output = False
for dim in out.shape:
if not isinstance(dim, tvm.tir.IntImm):
is_dynamic_output = True
break

out_len = utils.prod(out.shape)

try:
const_size = utils.get_const_int(utils.prod(out.shape))
const_size = utils.get_const_int(out_len)
need_block_split = const_size > max_block * num_thread * vector_width
except ValueError:
need_block_split = False
@@ -61,6 +69,9 @@ def schedule_injective_from_existing(sch, out):
sch[out].bind(bx, te.thread_axis("blockIdx.x"))
sch[out].bind(tx, te.thread_axis("threadIdx.x"))
else:
# Use fewer threads for dynamic shape ops to avoid runtime errors.
if is_dynamic_output:
num_thread //= 2
bx, tx = sch[out].split(fused, factor=num_thread)
sch[out].bind(tx, te.thread_axis("threadIdx.x"))
sch[out].bind(bx, te.thread_axis("blockIdx.x"))
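
Editorial note: a small sketch (not from the PR, symbolic dimension assumed for illustration) of the control flow above — get_const_int raises ValueError for a non-constant extent, which is what routes dynamic-shape outputs into the branch where this change halves num_thread.

import tvm
from tvm import te
from tvm.topi import utils

n = te.var("n")                            # dynamic dimension
out = te.placeholder((n, 16), name="out")

try:
    utils.get_const_int(utils.prod(out.shape))
except ValueError:
    # Non-constant extent: no block split is possible, and with this change
    # the schedule also halves num_thread to reduce invalid launch configs.
    print("dynamic output shape detected")
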
49 changes: 46 additions & 3 deletions python/tvm/topi/cuda/nms.py
@@ -22,7 +22,6 @@

from tvm.tir import if_then_else
from .sort import argsort, argsort_thrust
from .. import tag


def cuda_atomic_add_rule(op):
@@ -95,7 +94,7 @@ def rearrange_indices_out_ir(data, output, valid_box_count):
with ib.new_scope():
i = te.thread_axis("blockIdx.x")
ib.scope_attr(i, "thread_extent", batch_size)
valid_idx = ib.allocate("int32", (1), name="valid_idx", scope="local")
valid_idx = ib.allocate("int32", (1,), name="valid_idx", scope="local")
valid_idx[0] = 0
with ib.for_range(0, num_anchors, name="j") as j:
with ib.if_scope(data[i, j] >= 0):
@@ -654,6 +653,35 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
return ib.get()


def _fetch_score_ir(data, score, axis):
"""
Fetch score from data.
This routine is required for dynamic shape nms.
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
elem_length = data.shape[2]

ib = tvm.tir.ir_builder.create()

data = ib.buffer_ptr(data)
score = ib.buffer_ptr(score)
with ib.if_scope(num_anchors > 0):
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)

tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size * num_anchors):
score[tid] = data[tid * elem_length + axis]

return ib.get()


def non_max_suppression(
data,
valid_count,
@@ -754,7 +782,22 @@ def non_max_suppression(
)
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
Contributor: This looks fine, but I'm a little surprised it's necessary. Do you have a test case that breaks the current code, or is this mostly for performance?

Contributor Author: When the NMS workload is large, as in RCNN models, the general CUDA injective schedule can still cause runtime errors even with the improvements in this PR. Any dynamic-shape injective op can hit runtime issues with the current uniform CUDA injective schedule.

This problem is not directly related to NMS but to the CUDA injective schedule. Later we may need to revisit this part for GPU dynamic ops and come up with a better, more general solution (together with more tests).

score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
score_tensor = te.extern(
[score_shape],
[data],
lambda ins, outs: _fetch_score_ir(
ins[0],
outs[0],
score_axis,
),
dtype=[data.dtype],
in_buffers=[data_buf],
out_buffers=[score_buf],
name="fetch_score",
tag="fetch_score",
)
target = tvm.target.Target.current()
if (
target
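
Editorial note: a quick check (not part of the PR, sizes made up) of the flat indexing used in _fetch_score_ir above — data[tid * elem_length + axis] reads the score column of a row-major (batch, anchors, elem) buffer.

import numpy as np

batch, anchors, elem_length, score_axis = 2, 3, 6, 1
data = np.arange(batch * anchors * elem_length, dtype="float32")
boxes = data.reshape(batch, anchors, elem_length)  # row-major layout

for tid in range(batch * anchors):
    flat_score = data[tid * elem_length + score_axis]
    assert flat_score == boxes[tid // anchors, tid % anchors, score_axis]
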
3 changes: 3 additions & 0 deletions python/tvm/topi/cuda/sort.py
@@ -565,6 +565,9 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int
tag="topk_gpu",
)

if isinstance(k, tvm.tir.IntImm):
k = k.value

if not isinstance(k, int) or k > 0:
beg = [0] * ndim
end = data.shape[:-1] + [k if isinstance(k, int) else tvm.te.size_var("dim")]
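
Editorial note: an illustration (not from the PR) of the IntImm normalization above — a constant k coming from Relay typically arrives as a tvm.tir.IntImm, and .value converts it to a plain Python int so the existing isinstance(k, int) checks keep working.

import tvm

k = tvm.tir.IntImm("int32", 5)   # how a constant k typically arrives
if isinstance(k, tvm.tir.IntImm):
    k = k.value
assert isinstance(k, int) and k == 5
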
9 changes: 9 additions & 0 deletions src/runtime/contrib/thrust/thrust.cu
@@ -205,6 +205,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
@@ -215,6 +218,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
if (value_dtype == "int32") {
thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<int64_t, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
@@ -225,6 +231,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
if (value_dtype == "int32") {
thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "int64") {
thrust_stable_sort_by_key<float, int64_t>(keys_in, values_in, keys_out, values_out,
for_scatter);
} else if (value_dtype == "float32") {
thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
for_scatter);
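
Editorial note: as the review discussion further down points out, Thrust is needed for dynamic-shape sorting. A hedged sketch (not from the PR) for checking whether the Thrust extension was built in before relying on this kernel:

import tvm

# Returns None when TVM was built without the Thrust contrib (USE_THRUST off).
f = tvm.get_global_func("tvm.contrib.thrust.stable_sort_by_key", allow_missing=True)
print("thrust stable_sort_by_key available:", f is not None)
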
17 changes: 15 additions & 2 deletions src/runtime/vm/vm.cc
@@ -245,6 +245,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In
std::vector<int> codes(arity);
runtime::TVMArgsSetter setter(values.data(), codes.data());
int idx = 0;
bool is_empty_output = false;
for (Index i = 0; i < arg_count; i++) {
if (const auto* dt_cell = args[i].as<ADTObj>()) {
for (size_t fi = 0; fi < dt_cell->size; ++fi) {
@@ -254,12 +255,24 @@
}
} else {
auto nd_array = Downcast<NDArray>(args[i]);
// We can safely skip CallPacked if there is only one
// output and it is empty.
if (i == arg_count - 1 && output_size == 1) {
for (const auto& dim : nd_array.Shape()) {
if (!dim) {
is_empty_output = true;
break;
}
}
}
setter(idx++, nd_array);
}
}

TVMRetValue rv;
func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
if (!is_empty_output) {
TVMRetValue rv;
func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
}
}

void VirtualMachine::LoadExecutable(const Executable* exec) {
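
Editorial note: a sketch (not part of the PR) of the "empty output" condition mirrored in Python — an output tensor with any zero-length dimension has no elements, so skipping the packed-function call is safe when it is the only output.

import numpy as np
import tvm

out = tvm.nd.array(np.zeros((1, 0, 6), dtype="float32"))
is_empty_output = any(dim == 0 for dim in out.shape)
assert is_empty_output  # the VM now skips CallPacked for such single outputs
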
29 changes: 26 additions & 3 deletions tests/python/relay/test_any.py
@@ -54,6 +54,7 @@ def check_result(
for kind in ["debug", "vm"]:
targets = targets or tvm.testing.enabled_targets()
for tgt, ctx in targets:
print(tgt)
if disable_targets and tgt in disable_targets:
continue
if kind == "debug" and (only_vm or ctx.device_type != tvm.cpu().device_type):
@@ -199,6 +200,15 @@ def test_any_concat():
ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0)
check_result([x_np, y_np], mod, ref)

num_inputs = 25
x = [relay.var("x", shape=(relay.Any(),), dtype="float32") for _ in range(num_inputs)]
z = relay.op.concatenate(x, axis=0)
Contributor: This tests the injective schedule 👍

mod = tvm.IRModule()
mod["main"] = relay.Function(x, z)
x_np = [np.random.uniform(size=(1,)).astype("float32") for _ in range(num_inputs)]
ref = np.concatenate(x_np, axis=0)
check_result(x_np, mod, ref)


def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False):
x = relay.var("x", shape=x_shape, dtype="float32")
@@ -572,9 +582,7 @@ def verify_any_conv2d_transpose_nchw(
mod["main"] = relay.Function([data, kernel], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
check_result(
[data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=[("llvm", tvm.cpu())]
)
check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True)


# TODO(@kevinthesun): Support dynamic input height and width.
@@ -1430,6 +1438,21 @@ def test_non_max_suppression():
disable_targets=["nvptx"],
)

np_data = np.zeros((1, 0, 6)).astype("float32")
np_valid_count = np.array([0]).astype("int32")
np_indices = np.zeros((1, 0)).astype("int32")
np_max_output_size = -1
np_indices_result = np.zeros((1, 0))
np_valid_box_count = np.array([[0]]).astype("int32")

check_result(
[np_data, np_valid_count, np_indices, np_max_output_size],
mod,
[np_indices_result, np_valid_box_count],
only_vm=False,
disable_targets=["nvptx"],
Contributor: This tests the empty-output VM change 👍
Why disable nvptx?

Contributor Author: There is an issue causing a segfault in dynamic NMS on nvptx, and in general we need Thrust for any dynamic-shape sorting. For now, nvptx is not ready for these operations.

Contributor: Makes sense. I'm trying to fix the default sort kernel in #7099, if you want to take a look.

)


if __name__ == "__main__":
pytest.main([__file__])