
Conversation

@lsy643 (Contributor) commented Jul 22, 2020

In this PR, the CUDA compute functions of get_valid_counts and nms are changed to make them work as expected.

  1. For get_valid_counts, only one thread is used per image. I am not sure whether this is a good way.
  2. For nms, there are two changes:
    2.1 make box_indices map back to the original data indices
    2.2 create rearrange_indices_out for nms when return_indices == True (see the sketch after this list)
  3. Test cases for the gpu versions of get_valid_counts and nms are enabled now.
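
For illustration, here is a rough numpy reference model of the semantics rearrange_indices_out is meant to have (a hypothetical sketch, not the actual TIR implementation): valid indices are compacted to the front of each row and the tail is padded with -1.

import numpy as np

def rearrange_indices_ref(indices):
    # Reference model: move entries >= 0 to the front of each row, pad with -1.
    out = np.full_like(indices, -1)
    for b in range(indices.shape[0]):
        valid = indices[b][indices[b] >= 0]
        out[b, :len(valid)] = valid
    return out

# [[0, 1, -1, 3, 4]] -> [[0, 1, 3, 4, -1]]
print(rearrange_indices_ref(np.array([[0, 1, -1, 3, 4]], dtype="int32")))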

@lsy643 (Author) commented Jul 22, 2020

cc @yongwww @Laurawly

if target in ['cuda', 'opencl']:
    return
# get_valid_count for opencl doesn't do data rearrangement
if target in ['opencl']:
@trevor-m (Contributor) commented Jul 22, 2020

OpenCL shares the cuda implementation, so you can enable this test too. The CI doesn't run opencl so please test it manually.

@lsy643 (Author):

Test for opencl has been enabled.

with ib.if_scope(
        tvm.tir.all(data[tid * elem_length + score_index] > score_threshold,
                    tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0))):
    atomic_add_return[0] = atomic_add(tvm.tir.call_intrin("handle", "tir.address_of",
Contributor:
Since we are no longer using atomic_add, should we remove those intrinsic definitions?
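
For context, the definition in question is a thin wrapper of roughly this shape (a reconstructed sketch; the exact form in topi/cuda/nms.py may differ):

import tvm

def atomic_add(x, y):
    # Wrapper around the tir.atomic_add intrinsic; with the rewritten
    # get_valid_counts IR it no longer has any callers.
    return tvm.tir.call_intrin(y.dtype, "tir.atomic_add", x, y)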

@lsy643 (Author):

Unused atomic_add definitions have been removed.

@trevor-m (Contributor):

IIRC, data arrangement was removed from get_valid_counts to improve performance because the data arrangement would be done by NMS anyway. Does this PR maintain the performance? @Laurawly

@yongwww (Member) commented Jul 25, 2020

@lsy643 Regarding the thread change, could you please benchmark the performance before and after your change and share the numbers?

max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1
Member:

Would like to know more about the reasoning behind the change; perhaps share some benchmark numbers? In some scenarios, like TF MaskRCNN, a large number of boxes (num_anchors > 20000) are in one batch, so multiple threads here might provide a performance improvement.
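
For reference, the launch configuration above typically pairs with the standard IR-builder binding pattern, roughly as follows (a sketch, assuming ib is the tvm.tir.ir_builder instance from the surrounding function):

from tvm import te

# One CUDA thread per (batch, anchor) element; tid enumerates the flattened space.
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
with ib.if_scope(tid < batch_size * num_anchors):
    pass  # per-box work goes here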

@lsy643 (Author):

It seems that the gpu get_valid_counts does not need to move valid boxes to the top of the input data because of the argsort in the nms, so using the original get_valid_counts_ir ought to be better. Do you think rearrange_indices_out_ir is the right way to do it?

Contributor:

Only using one thread will regress the performance a lot. More benchmarks should be shown, covering the workloads from previous PRs.


return ib.get()


def rearrange_indices_out(data):
@Laurawly (Contributor) commented Jul 26, 2020

This will regress performance a lot, don't recommend.

@Laurawly (Contributor):

I'm wondering what the purpose of this PR is. Currently, there's no correctness issue with get_valid_counts and nms in the end-to-end performance of the object detection models using them. And based on the benchmarks shown in this PR, the recent changes brought a lot of performance gain. And as a summary of that PR, we know data rearrangement on the GPU is super slow.

@lsy643 (Author) commented Jul 27, 2020

@trevor-m @yongwww @Laurawly
For the get_valid_counts part, I misunderstood it because I didn't understand argsort correctly, and I have reverted to the original version, which is much faster.

For the rearrange_indices_out part, which is necessary because the result of nms is used by a strided_slice in def _nms in the TensorFlow frontend: I agree that the current way may regress the performance, but since we need to do data rearrangement in this function, I can hardly figure out a better way to implement it.
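
To make that constraint concrete, here is a hypothetical numpy sketch of what the frontend's strided_slice effectively does with the nms output: it keeps the first valid_count entries, which is only correct if the valid indices have been compacted to the front.

import numpy as np

# Assumed nms output for one image: selected box indices, padded with -1.
nms_indices = np.array([[3, 0, 1, -1, -1]], dtype="int32")
valid_count = np.array([3], dtype="int32")

# strided_slice equivalent: take the first valid_count[0] selections.
selected = nms_indices[0, :valid_count[0]]
print(selected)  # [3 0 1]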

@Laurawly (Contributor):

> @trevor-m @yongwww @Laurawly
> For the get_valid_counts part, I misunderstood it because I didn't understand argsort correctly, and I have reverted to the original version, which is much faster.
>
> For the rearrange_indices_out part, which is necessary because the result of nms is used by a strided_slice in def _nms in the TensorFlow frontend: I agree that the current way may regress the performance, but since we need to do data rearrangement in this function, I can hardly figure out a better way to implement it.

Could you show some benchmark numbers for the changes? @yongwww could have a better comment on the TensorFlow related changes. Also, it seems that there's an illegal memory access error based on CI.

@yongwww (Member) commented Aug 24, 2020

@lsy643 the rearrange_indices_out part you updated looks good to me. Currently I am concerned about the thread related change, since it might cause some performance regression, especially for scenarios with >20k input boxes. It would be great to see some benchmark numbers for the change.

@lsy643 (Author) commented Aug 26, 2020

@yongwww @Laurawly
Sorry for the late response; I have been quite busy at work recently. I will try to fix the error and run the benchmark by the end of this week.

@lsy643 (Author) commented Aug 30, 2020

@yongwww @Laurawly
I am quite confused about the test data used in test_non_max_suppression from tests/python/relay/test_op_level5.py.

If I understand correctly, a get_valid_counts, a non_max_suppression and a strided_slice are used together as one non_max_suppression operator, according to def _nms() in frontend/tensorflow.py. The cuda version of get_valid_counts does not move valid boxes to the top of the input data, while the cpu version does.

Therefore, it seems to make more sense to use different test data for test_non_max_suppression. For example:

the original data before get_valid_counts

np_data = np.array([[[0, 0.8, 1, 20, 25, 45], 
                     [1, 0.7, 30, 60, 50, 80],
                     [0, 0.4, 4, 21, 19, 40], 
                     [2, 0.9, 35, 61, 52, 79],
                     [1, 0.5, 100, 60, 70, 110]]]).astype("float32")
cpu test data after get_valid_counts

np_data = np.array([[[0, 0.8, 1, 20, 25, 45], 
                     [1, 0.7, 1, 20, 25, 45],
                     [2, 0.9, 35, 61, 52, 79],
                     [1, 0.5, 100, 60, 70, 110], 
                     [-1, -1, -1, -1, -1, -1]]]).astype("float32")

np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32")
cuda test data after get_valid_counts

np_data = np.array([[[0, 0.8, 1, 20, 25, 45], 
                     [1, 0.7, 2, 21, 26, 45],
                     [-1, -1, -1, -1, -1, -1], 
                     [2, 0.9, 35, 61, 52, 79],
                     [1, 0.5, 100, 60, 70, 110]]]).astype("float32")

np_indices = np.array([[0, 1, -1, 3, 4]]).astype("int32")
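
A hypothetical numpy sketch that derives the cuda-style output above from the original data (assuming a score_threshold of 0.4, and ignoring the small coordinate differences in the hand-written example): valid rows stay at their original positions, everything else is masked with -1.

import numpy as np

np_data = np.array([[[0, 0.8, 1, 20, 25, 45],
                     [1, 0.7, 30, 60, 50, 80],
                     [0, 0.4, 4, 21, 19, 40],
                     [2, 0.9, 35, 61, 52, 79],
                     [1, 0.5, 100, 60, 70, 110]]]).astype("float32")

score_threshold = 0.4  # assumed for this example
valid = np_data[0, :, 1] > score_threshold
cuda_data = np.where(valid[:, None], np_data[0], -1.0)[None].astype("float32")
cuda_indices = np.where(valid, np.arange(np_data.shape[1]), -1)[None].astype("int32")
print(cuda_indices)  # [[ 0  1 -1  3  4]]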

@yongwww (Member) commented Sep 4, 2020

@lsy643 you are right, the auxiliary ops get_valid_counts and strided_slice are used to help handle TensorFlow's dynamic NonMaximumSuppression. As a todo task, the cpu and gpu versions of the op are expected to behave consistently.

@lsy643 (Author) commented Sep 11, 2020

@yongwww
I have added a test case for the nms cuda version in test_op_level5.py, with test data assumed to come from a get_valid_counts.

Since there is no rearrange_indices_out for the nms cuda version, I only compare it with the llvm version:

  1. For test data with shape [1, 5, 6]
  • cuda time: 90us
  • llvm time: 32us

  2. For test data with shape [1, 20000, 6]
  • cuda time: 6230us
  • llvm time: 219209us

The inference time for llvm on the large dataset is far too high.

Test data I use

import numpy as np

data_length = 20000
np_valid_count = np.array([data_length]).astype("int32")
np_indices = np.array([list(range(data_length))]).astype("int32")

np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
                     [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
                     [1, 0.5, 100, 60, 70, 110]]]).astype("float32")
# Repeat the 5 base boxes to fill data_length anchors.
np_data = np_data.repeat(data_length // 5, axis=1)

The compute and schedule functions I use

    use_cuda = False
    if use_cuda:
        device = 'cuda'
        fcompute = topi.cuda.non_max_suppression
        fschedule = topi.cuda.schedule_nms
    else:
        device = 'llvm'
        fcompute = topi.vision.non_max_suppression
        fschedule = topi.generic.schedule_nms
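
For completeness, a rough sketch of the timing harness around these (hypothetical; the exact non_max_suppression argument list and output packing vary across TVM revisions, so the simpler return_indices=False path with a single output tensor is shown):

import numpy as np
import tvm
from tvm import te

ctx = tvm.context(device, 0)
data = te.placeholder(np_data.shape, name="data", dtype="float32")
valid_count = te.placeholder(np_valid_count.shape, name="valid_count", dtype="int32")
indices = te.placeholder(np_indices.shape, name="indices", dtype="int32")

with tvm.target.Target(device):
    out = fcompute(data, valid_count, indices, return_indices=False)
    s = fschedule(out)

f = tvm.build(s, [data, valid_count, indices, out], device)
tvm_out = tvm.nd.array(np.zeros(np_data.shape, dtype="float32"), ctx)
evaluator = f.time_evaluator(f.entry_name, ctx, number=100)
t = evaluator(tvm.nd.array(np_data, ctx), tvm.nd.array(np_valid_count, ctx),
              tvm.nd.array(np_indices, ctx), tvm_out).mean
print("%.1f us" % (t * 1e6))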

@yongwww (Member) commented Sep 15, 2020

@lsy643 thanks for sharing the results. What I am wondering about is the latency of your change vs the previous nms gpu version (even if the output is not identical), and possibly the perf numbers of your change vs a TensorFlow baseline. As Leyuan mentioned above, the thread related change might cause a performance regression; performance matters a lot to us, so we would like to see some perf numbers for this. If a performance regression does exist, it should be fixed.

@lsy643 (Author) commented Sep 15, 2020

@yongwww The rearrange_indices_out is a newly added part of non_max_suppression in topi/cuda/nms.py, so I compared the latency of non_max_suppression with return_indices=True and with return_indices=False.

For test data with shape [1, 20000, 6]

When return_indices=True, the rearrange_indices_out is used

  • latency: 6342us

When return_indices=False, the rearrange_indices_out is not used

  • latency: 5645us

@tqchen tqchen changed the base branch from master to main October 11, 2020 18:28
@Laurawly (Contributor) left a review:

LGTM. Please resolve conflicts with main branch.

@Laurawly (Contributor) commented Dec 7, 2020

@lsy643 @yongwww @trevor-m It seems the purpose of this PR has already been served by PR #7005, which has been merged upstream. Closing this PR for now. Please reopen if there are other issues.

@Laurawly closed this Dec 7, 2020