From 2df53d431affccd283512f9b55a781ddeea76af8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 30 May 2021 12:15:31 +0900 Subject: [PATCH 01/17] import from branch commit c86bcf48fa6acd19647a7a096b9e1a5d4e56cc74 Merge: 0fa88051b da75b2a52 Author: Masahiro Masuda Date: Sun May 30 12:13:29 2021 +0900 Merge branch 'tmp' into all_class_nms_tf commit 0fa88051b3d30337674a07e579b78e8cb254cd66 Author: Masahiro Masuda Date: Sun May 30 06:24:57 2021 +0900 Revert "handling case when num detections is smaller than max_total_size" This reverts commit 61e70b82f338300224b22f4d6bdda349e7aa5aca. commit 67251504c652e36106718617c5ae8b42c61deffc Author: Masahiro Masuda Date: Sun May 30 05:43:06 2021 +0900 handling case when num detections is smaller than max_total_size commit 39549aa25267617671ca2a82ca517442065afe97 Author: Masahiro Masuda Date: Sun May 30 05:32:37 2021 +0900 simplify frontend commit ca9470ba68e68c81902b0a3bad4bf5b5f0aa311e Author: Masahiro Masuda Date: Sun May 30 05:25:13 2021 +0900 update op definition commit 47bdef9e0fcdbab4671dd46044be5acac24b2f2b Author: Masahiro Masuda Date: Sat May 29 19:47:04 2021 +0900 remove unnecessary mask commit 445a7daf1afb794be5f03473c70b172f06556d05 Author: Masahiro Masuda Date: Sat May 29 16:54:19 2021 +0900 remove in_buffer commit 71879b115b3bfe8087b73618bdab16fd61fbed86 Author: Masahiro Masuda Date: Sat May 29 16:48:22 2021 +0900 minor fix commit 72e055a721ee7d698e7b3a3f58ab074ab78b57b2 Author: Masahiro Masuda Date: Sat May 29 16:45:37 2021 +0900 make it more readable commit a1fe7c46d6bb77da51de24b46bec881ed19d4cb3 Author: Masahiro Masuda Date: Sat May 29 16:44:14 2021 +0900 clean up commit 0c659bf27f9b90dd53455abd0d24b42f86e802bb Author: Masahiro Masuda Date: Sat May 29 16:33:54 2021 +0900 improve sort on cpu commit 480f6b782427dd46514efbf7028de4fb9f5ff9aa Author: Masahiro Masuda Date: Sat May 29 16:29:53 2021 +0900 collect indices and scores in one kernel commit 2b441c391a25930092b55f12dadc31832400277b Author: Masahiro Masuda Date: Sat May 29 15:47:31 2021 +0900 initialization bug fixed in cuda commit d43e801289621e71a79c71308dedeef0969264be Author: Masahiro Masuda Date: Sat May 29 15:23:09 2021 +0900 cpu nms bug fixed commit 025010e42110388d0de2bc2ffcd76fbe14a188fb Author: Masahiro Masuda Date: Sat May 29 11:09:47 2021 +0900 add cpu impl commit 787d8399ff160694ecd2c4a9721a5825ca945d81 Author: Masahiro Masuda Date: Sat May 29 10:38:20 2021 +0900 refactoring commit 05404305d1323b475829de10aa68f1f8791686cc Author: Masahiro Masuda Date: Sat May 29 10:03:51 2021 +0900 initial import commit 5ff0985625ec75f117af37017ebf4089dafb8a46 Author: Masahiro Masuda Date: Sat May 29 10:02:45 2021 +0900 cleanup commit 199f9b67c2d471a761f743e6ea5fa414c899bd3f Author: Masahiro Masuda Date: Sat May 29 10:00:15 2021 +0900 Revert "add gather_nd shape func" This reverts commit 1ff4d53f057e7bfd1c6dff31a81f866727bef855. 
commit 47a05c4c8f5a56a1685848210229aaa083b92880 Author: Masahiro Masuda Date: Sat May 29 09:53:00 2021 +0900 format commit 9dcd0f02b25d658c94fa23e2cc65a9424ed8a1a5 Author: Masahiro Masuda Date: Sat May 29 09:48:43 2021 +0900 make it static commit eb06393939f1b8d8130f3815dc0f66223c9aa4f3 Author: Masahiro Masuda Date: Sat May 29 09:14:31 2021 +0900 restore old impl and use it for q != 1 case commit 115a5dfcf9b552fb2682534d82bbf638e661c0aa Author: Masahiro Masuda Date: Sat May 29 09:00:40 2021 +0900 fixed score gathering commit d2035626a72f8df71c514e3293337f6fda723353 Author: Masahiro Masuda Date: Sat May 29 08:53:14 2021 +0900 minimum fixed commit 3fe91e8846b6d2075ae1d9a162c4b70b08cc8024 Author: Masahiro Masuda Date: Sat May 29 06:59:39 2021 +0900 batch issue fixed commit 19e3e84690c0289c85001597046969d0c8dc92c2 Author: Masahiro Masuda Date: Sat May 29 04:29:15 2021 +0900 zero padding working This reverts commit 58c3413a30e5b03208b6281651d38ee02c44f9c1. commit ce7848ba7def5a22659b09de039b2df12c0114f9 Author: Masahiro Masuda Date: Fri May 28 13:12:47 2021 +0900 pylint, do not use -1 for default value commit 968f3bd230ed4855b45fd739dfc86edc2335ec80 Author: Masahiro Masuda Date: Fri May 28 13:07:31 2021 +0900 rename to index_rank and make it Optional commit 9e06b8491e0ce1c981a5059f28135319f96978d0 Author: Masahiro Masuda Date: Fri May 21 18:01:59 2021 +0900 fix pylint commit 81dc6050dcbe59915a8f9b4f78b4aaf9fdba89a6 Author: Masahiro Masuda Date: Fri May 21 17:57:03 2021 +0900 minor fix commit 54297b6128863d07e7dded71cc40077726faf2db Author: Masahiro Masuda Date: Fri May 21 17:54:16 2021 +0900 support dynamic scatter nd commit e25c225ce747c4e84452e6e7b32eeb0d71b2995d Author: Masahiro Masuda Date: Fri May 21 17:33:19 2021 +0900 gather_dim -> num_indices_per_tuple commit aaa6211e7ef3ce520b8711a78cf7eb2af52e7acc Author: Masahiro Masuda Date: Fri May 21 17:23:46 2021 +0900 add dynamic gather_nd test commit 3a9fe5dfa5faeadbcdb882ff70039ea7bccb61a3 Author: Masahiro Masuda Date: Fri May 21 17:18:26 2021 +0900 refactor gather_nd ref funcs commit 1ff4d53f057e7bfd1c6dff31a81f866727bef855 Author: Masahiro Masuda Date: Fri May 21 14:36:34 2021 +0900 add gather_nd shape func commit b0200643a184294a2f2b3cce7208c4d257987424 Author: Masahiro Masuda Date: Sat May 29 04:01:11 2021 +0900 working on zero padding commit 456741790dd5e73f3f76ef7a5ede6e1014de8b2d Author: Masahiro Masuda Date: Sat May 29 03:21:52 2021 +0900 working commit 7f5c76d0090950985888781f071ca341e2fa5695 Author: Masahiro Masuda Date: Sat May 29 02:37:50 2021 +0900 relay type inference works, debugging topi commit 4a4b8dfbfdc65d7a6e77ed0a6e8b09af162b77ad Author: Masahiro Masuda Date: Fri May 28 15:08:16 2021 +0900 add max_total_size to attributes commit 7218b2f7b4de0c796d69f23084cd688e28f7b461 Author: Masahiro Masuda Date: Fri May 28 14:50:58 2021 +0900 tf frontend update commit cde4a1fdd15ed898b1f7299e99377cceaaee2732 Author: Masahiro Masuda Date: Fri May 28 14:17:14 2021 +0900 all class nms tf mode first cut commit 5f349f77c9c230ee636aceb52547502319c8ad77 Author: Masahiro Masuda Date: Fri May 28 06:54:34 2021 +0900 begin supporting per batch output commit 0044365affac6667a02d15791a59040702f8990b Author: Trevor Morris Date: Mon May 3 19:46:28 2021 +0000 initial commit 168a617e48b062417b766d6400b0c6b856084cfa Author: Trevor Morris Date: Fri Apr 16 20:31:32 2021 +0000 initia; l commit da75b2a52e9a8daa322168d1b6026e144d42d5bb Author: Masahiro Masuda Date: Sun May 30 07:58:19 2021 +0900 do minimum in topi commit 
52c5e8a5bca56f93778990d4faa87c7e7b342ba7 Author: Masahiro Masuda Date: Sun May 30 07:54:49 2021 +0900 more simplify commit 44d88cdecd87468b630fd16a7d1e1214e86eabfa Author: Masahiro Masuda Date: Sun May 30 07:51:39 2021 +0900 simplify commit 74e19174f1b2d40ee8f4d08c7a61667bc3dd69b5 Author: Masahiro Masuda Date: Sun May 30 07:39:37 2021 +0900 black commit fc3a38e1cb699b66340c7742cb74188fdbe92bf5 Author: Masahiro Masuda Date: Sun May 30 07:37:30 2021 +0900 minor change commit f88e2a3a98a7ee283622e57712e28634374e5e2c Author: Masahiro Masuda Date: Sun May 30 07:14:54 2021 +0900 minor refactor commit f2d7ed410a0b835586929706873ee1f448d1f955 Author: Masahiro Masuda Date: Sun May 30 07:08:47 2021 +0900 support the case when there is not enough box commit 0f184a6bf6e533c91d26c98a7a17f8d7970364cc Author: Masahiro Masuda Date: Sun May 30 06:24:16 2021 +0900 Revert "handling case when num detections is smaller than max_total_size" This reverts commit 61e70b82f338300224b22f4d6bdda349e7aa5aca. commit d7180f27cfaffbbd1ab1ce970ca605133bc812ee Merge: 61e70b82f 06ac2052a Author: Masahiro Masuda Date: Sun May 30 05:43:37 2021 +0900 Merge branch 'gather_nd_shape_func' into tmp commit 61e70b82f338300224b22f4d6bdda349e7aa5aca Author: Masahiro Masuda Date: Sun May 30 05:43:06 2021 +0900 handling case when num detections is smaller than max_total_size commit 453a79bd05f67653be8b90db80ecde12d343aea6 Author: Masahiro Masuda Date: Sun May 30 05:32:37 2021 +0900 simplify frontend commit 2fc5f1ed3de49266f1eb72aed25d26457da78491 Author: Masahiro Masuda Date: Sun May 30 05:25:13 2021 +0900 update op definition commit 8afbd30c0fbbd40902acc4196a18b448f2a93266 Author: Masahiro Masuda Date: Sat May 29 19:47:04 2021 +0900 remove unnecessary mask commit ff870f7e972e289953ca0e5daa444c09e5095efa Author: Masahiro Masuda Date: Sat May 29 16:54:19 2021 +0900 remove in_buffer commit e71b922b6cdf129ef51e91928635374a1f02a6fc Author: Masahiro Masuda Date: Sat May 29 16:48:22 2021 +0900 minor fix commit b02faaead24d2d14d3b67bf04ee23f9df9bfecbe Author: Masahiro Masuda Date: Sat May 29 16:45:37 2021 +0900 make it more readable commit 6baee99ed1b57be8da06c00e17d6b92083668ac0 Author: Masahiro Masuda Date: Sat May 29 16:44:14 2021 +0900 clean up commit 7a2a2df8b696faf7c4280fd9a3f9fbdf8f5c3e03 Author: Masahiro Masuda Date: Sat May 29 16:33:54 2021 +0900 improve sort on cpu commit afad2a2e920c98d269c6000035f31392cff7b6a3 Author: Masahiro Masuda Date: Sat May 29 16:29:53 2021 +0900 collect indices and scores in one kernel commit c5718e299a82ffe5e60bc1fee679b2b0405346e5 Author: Masahiro Masuda Date: Sat May 29 15:47:31 2021 +0900 initialization bug fixed in cuda commit 5623e3f8f71de1dbec55c83a300fa4131cd82aad Author: Masahiro Masuda Date: Sat May 29 15:23:09 2021 +0900 cpu nms bug fixed commit c40eaecd87513a6869098ed95c03b8553c350414 Author: Masahiro Masuda Date: Sat May 29 11:09:47 2021 +0900 add cpu impl commit 6c7aaeb44f5586b57e7b1bfd7772d1b78a9eae1f Author: Masahiro Masuda Date: Sat May 29 10:38:20 2021 +0900 refactoring commit 7b87922279121f06cdcc77a41ac6c8f59b6d5549 Author: Masahiro Masuda Date: Sat May 29 10:03:51 2021 +0900 initial import commit 5ff0985625ec75f117af37017ebf4089dafb8a46 Author: Masahiro Masuda Date: Sat May 29 10:02:45 2021 +0900 cleanup commit 199f9b67c2d471a761f743e6ea5fa414c899bd3f Author: Masahiro Masuda Date: Sat May 29 10:00:15 2021 +0900 Revert "add gather_nd shape func" This reverts commit 1ff4d53f057e7bfd1c6dff31a81f866727bef855. 
commit 47a05c4c8f5a56a1685848210229aaa083b92880 Author: Masahiro Masuda Date: Sat May 29 09:53:00 2021 +0900 format commit 9dcd0f02b25d658c94fa23e2cc65a9424ed8a1a5 Author: Masahiro Masuda Date: Sat May 29 09:48:43 2021 +0900 make it static commit eb06393939f1b8d8130f3815dc0f66223c9aa4f3 Author: Masahiro Masuda Date: Sat May 29 09:14:31 2021 +0900 restore old impl and use it for q != 1 case commit 115a5dfcf9b552fb2682534d82bbf638e661c0aa Author: Masahiro Masuda Date: Sat May 29 09:00:40 2021 +0900 fixed score gathering commit d2035626a72f8df71c514e3293337f6fda723353 Author: Masahiro Masuda Date: Sat May 29 08:53:14 2021 +0900 minimum fixed commit 3fe91e8846b6d2075ae1d9a162c4b70b08cc8024 Author: Masahiro Masuda Date: Sat May 29 06:59:39 2021 +0900 batch issue fixed commit 19e3e84690c0289c85001597046969d0c8dc92c2 Author: Masahiro Masuda Date: Sat May 29 04:29:15 2021 +0900 zero padding working This reverts commit 58c3413a30e5b03208b6281651d38ee02c44f9c1. commit ce7848ba7def5a22659b09de039b2df12c0114f9 Author: Masahiro Masuda Date: Fri May 28 13:12:47 2021 +0900 pylint, do not use -1 for default value commit 968f3bd230ed4855b45fd739dfc86edc2335ec80 Author: Masahiro Masuda Date: Fri May 28 13:07:31 2021 +0900 rename to index_rank and make it Optional commit 9e06b8491e0ce1c981a5059f28135319f96978d0 Author: Masahiro Masuda Date: Fri May 21 18:01:59 2021 +0900 fix pylint commit 81dc6050dcbe59915a8f9b4f78b4aaf9fdba89a6 Author: Masahiro Masuda Date: Fri May 21 17:57:03 2021 +0900 minor fix commit 54297b6128863d07e7dded71cc40077726faf2db Author: Masahiro Masuda Date: Fri May 21 17:54:16 2021 +0900 support dynamic scatter nd commit e25c225ce747c4e84452e6e7b32eeb0d71b2995d Author: Masahiro Masuda Date: Fri May 21 17:33:19 2021 +0900 gather_dim -> num_indices_per_tuple commit aaa6211e7ef3ce520b8711a78cf7eb2af52e7acc Author: Masahiro Masuda Date: Fri May 21 17:23:46 2021 +0900 add dynamic gather_nd test commit 3a9fe5dfa5faeadbcdb882ff70039ea7bccb61a3 Author: Masahiro Masuda Date: Fri May 21 17:18:26 2021 +0900 refactor gather_nd ref funcs commit 1ff4d53f057e7bfd1c6dff31a81f866727bef855 Author: Masahiro Masuda Date: Fri May 21 14:36:34 2021 +0900 add gather_nd shape func commit b0200643a184294a2f2b3cce7208c4d257987424 Author: Masahiro Masuda Date: Sat May 29 04:01:11 2021 +0900 working on zero padding commit 456741790dd5e73f3f76ef7a5ede6e1014de8b2d Author: Masahiro Masuda Date: Sat May 29 03:21:52 2021 +0900 working commit 7f5c76d0090950985888781f071ca341e2fa5695 Author: Masahiro Masuda Date: Sat May 29 02:37:50 2021 +0900 relay type inference works, debugging topi commit 4a4b8dfbfdc65d7a6e77ed0a6e8b09af162b77ad Author: Masahiro Masuda Date: Fri May 28 15:08:16 2021 +0900 add max_total_size to attributes commit 7218b2f7b4de0c796d69f23084cd688e28f7b461 Author: Masahiro Masuda Date: Fri May 28 14:50:58 2021 +0900 tf frontend update commit cde4a1fdd15ed898b1f7299e99377cceaaee2732 Author: Masahiro Masuda Date: Fri May 28 14:17:14 2021 +0900 all class nms tf mode first cut commit 5f349f77c9c230ee636aceb52547502319c8ad77 Author: Masahiro Masuda Date: Fri May 28 06:54:34 2021 +0900 begin supporting per batch output commit 0044365affac6667a02d15791a59040702f8990b Author: Trevor Morris Date: Mon May 3 19:46:28 2021 +0000 initial commit 168a617e48b062417b766d6400b0c6b856084cfa Author: Trevor Morris Date: Fri Apr 16 20:31:32 2021 +0000 initia; l commit 06ac2052ab843be950ff3abf6ce8d52803adc5e5 Author: Masahiro Masuda Date: Fri May 28 13:12:47 2021 +0900 pylint, do not use -1 for default value commit 
2adc42618580c967bd49d53c0724382f9cf87772 Author: Masahiro Masuda Date: Fri May 28 13:07:31 2021 +0900 rename to index_rank and make it Optional commit c458da6e80b0ff7b6e2ca729a49755f42dfe3702 Author: Masahiro Masuda Date: Fri May 21 18:01:59 2021 +0900 fix pylint commit b7faf0f93bd3ba4fc0eb88f1fac31c8d9525c883 Author: Masahiro Masuda Date: Fri May 21 17:57:03 2021 +0900 minor fix commit c03164116046670963f1d04529bfe94c5030ad17 Author: Masahiro Masuda Date: Fri May 21 17:54:16 2021 +0900 support dynamic scatter nd commit 56f3f0ea3fae4ba049101fcb4571b8999a3bda1c Author: Masahiro Masuda Date: Fri May 21 17:33:19 2021 +0900 gather_dim -> num_indices_per_tuple commit 081823b0129093602bb7f512f326eeb10bfb1906 Author: Masahiro Masuda Date: Fri May 21 17:23:46 2021 +0900 add dynamic gather_nd test commit 6b2655baf867b4d08e7d21ffe5f854228ced57e9 Author: Masahiro Masuda Date: Fri May 21 17:18:26 2021 +0900 refactor gather_nd ref funcs commit f9f5dfbe2a65eff8aa6718bf05fd8a843c5df08f Author: Masahiro Masuda Date: Fri May 21 14:36:34 2021 +0900 add gather_nd shape func --- include/tvm/relay/attrs/vision.h | 8 +- python/tvm/relay/frontend/tensorflow.py | 98 ++++++++++++++++- python/tvm/relay/op/strategy/generic.py | 12 ++- python/tvm/relay/op/vision/nms.py | 23 +++- python/tvm/topi/cuda/nms.py | 101 ++++++++++++++++-- python/tvm/topi/cuda/vision.py | 4 +- python/tvm/topi/transform.py | 4 +- python/tvm/topi/vision/nms.py | 99 +++++++++++++++-- python/tvm/topi/vision/nms_util.py | 82 ++++++++------ src/relay/op/vision/nms.cc | 39 +++++-- src/topi/transform.cc | 2 +- .../frontend/tensorflow/test_forward.py | 1 - 12 files changed, 402 insertions(+), 71 deletions(-) diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 005b900d5d44..3a61f18eb36e 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -117,8 +117,14 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { + Optional max_total_size; + std::string output_format; + TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs, - "relay.attrs.AllClassNonMaximumSuppressionAttrs") {} + "relay.attrs.AllClassNonMaximumSuppressionAttrs") { + TVM_ATTR_FIELD(max_total_size).set_default(NullValue()).describe("TODO"); + TVM_ATTR_FIELD(output_format).set_default("onnx").describe("Output format. onnx or tensorflow"); + } }; /*! 
\brief Attributes used in roi_align operators */ diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 040f8384dbe0..4a0f0670baec 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -794,6 +794,85 @@ def _impl(inputs, attr, params, mod): def _combined_nms(): + def all_class_impl( + batch_size, + max_output_boxes_per_batch, + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size, + clip_boxes, + ): + ( + selected_indices, + selected_scores, + num_detections, + ) = _op.vision.all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size, + output_format="tensorflow", + ) + box_range = _op.arange( + _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64" + ) + tile_batch_reps = ( + _op.concatenate([batch_size, 1]) + if isinstance(batch_size, tvm.tir.Any) + else _op.const([batch_size, 1]) + ) + box_range_2d = _op.tile(box_range, tile_batch_reps) + valid_mask = _op.cast( + _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32" + ) + + def select_topk(do_zero_pad): + def true_branch(): + arange = _op.arange( + _op.const(0, dtype="int64"), + _op.const(max_output_boxes_per_batch, dtype="int64"), + dtype="int64", + ) + pad = _op.full( + _op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,) + ) + topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps) + nmsed_scores = _op.gather(selected_scores, 1, topk_indices) + nmsed_scores = nmsed_scores * valid_mask + return nmsed_scores, topk_indices + + def false_branch(): + return _op.topk(selected_scores, k=max_total_size, axis=1, ret_type="both") + + # TODO(masahi): support dynamic num_boxes + # return _expr.If(do_zero_pad, true_branch(), false_branch()) + return true_branch() if do_zero_pad else false_branch() + + assert isinstance( + max_output_boxes_per_batch, int + ), "dynamic number of boxes not supported yet." + nmsed_scores, topk_indices = select_topk(max_output_boxes_per_batch < max_total_size) + + indices = _op.take(selected_indices, topk_indices, axis=1, batch_dims=1) + nmsed_box_indices = _op.take(indices, _op.const(1), axis=2) + nmsed_classes = _op.take(indices, _op.const(0), axis=2) + nmsed_boxes = _op.take(boxes, nmsed_box_indices, axis=1, batch_dims=1) + + if clip_boxes: + nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) + nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) + + nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2) + + return _expr.TupleWrapper( + _expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4 + ) + def _impl(inputs, attr, params, mod): # Get parameter values boxes = inputs[0] @@ -821,9 +900,22 @@ def _impl(inputs, attr, params, mod): q = boxes_shape[2] num_classes = scores_shape[2] - if q != num_classes: - # When q is 1, it means same box coords are used for all classes. 
- boxes = _op.broadcast_to(boxes, (batch_size, num_anchors, num_classes, 4)) + if q == 1 and isinstance(num_anchors, int): + boxes = _op.squeeze(boxes, axis=[2]) + scores_trans = _op.transpose(scores, [0, 2, 1]) + max_output_boxes_per_batch = num_anchors * num_classes + return all_class_impl( + batch_size, + max_output_boxes_per_batch, + boxes, + scores_trans, + max_output_size, + iou_threshold, + score_threshold, + max_total_size.data.numpy().item(), + attr["clip_boxes"], + ) + boxes = _op.reshape(boxes, newshape=[batch_size, num_anchors * num_classes, 4]) scores = _op.reshape(scores, newshape=[batch_size, num_anchors * num_classes, 1]) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 0d6c3ef58cdf..451d01a4fc05 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1095,7 +1095,17 @@ def _compute_nms(attrs, inputs, out_type): max_output_size = inputs[2] iou_threshold = inputs[3] score_threshold = inputs[4] - return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold) + max_total_size = attrs.max_total_size + output_format = attrs.output_format + return topi_compute( + inputs[0], + inputs[1], + max_output_size, + iou_threshold, + score_threshold, + max_total_size, + output_format, + ) return _compute_nms diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 3f829e0b1cc7..dc1a08e4b6ac 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -152,7 +152,13 @@ def non_max_suppression( def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0 + boxes, + scores, + max_output_boxes_per_class=-1, + iou_threshold=-1.0, + score_threshold=-1.0, + max_total_size=None, + output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -185,6 +191,7 @@ def all_class_non_max_suppression( in descending of scores, followed by boxes from batch 0, class 1 etc. Out of `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` rows are valid. 
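For illustration, the ONNX-style output described in the docstring above can be consumed as follows once the op has been executed. This is a sketch only; `indices_np` and `num_total_np` are assumed NumPy copies of the two outputs and are not names used anywhere in this patch:

    import numpy as np

    # Placeholder results standing in for the op's two outputs.
    indices_np = np.zeros((2 * 3 * 10, 3), dtype="int64")   # (batch_size * num_class * num_boxes, 3)
    num_total_np = np.array([7], dtype="int64")             # (1,)

    valid = indices_np[: int(num_total_np[0])]              # only the first rows are valid
    batch_ids, class_ids, box_ids = valid[:, 0], valid[:, 1], valid[:, 2]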
+ TODO(trvmorr): explain tf mode """ if not isinstance(max_output_boxes_per_class, expr.Expr): max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32") @@ -194,6 +201,16 @@ def all_class_non_max_suppression( score_threshold = expr.const(score_threshold, "float32") out = _make.all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size, + output_format, ) - return expr.TupleWrapper(out, 2) + + if output_format == "onnx": + return expr.TupleWrapper(out, 2) + + return expr.TupleWrapper(out, 3) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 9a3b86d72b18..2ca247df413f 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -24,6 +24,7 @@ from tvm.ir import register_intrin_lowering from tvm.tir import if_then_else from .sort import argsort, argsort_thrust +from ..broadcast import minimum from .scan import exclusive_scan from ..utils import ceil_div from ..math import cast @@ -32,6 +33,7 @@ calculate_overlap, binary_search, collect_selected_indices, + collect_selected_indices_and_scores, run_all_class_nms, ) @@ -988,8 +990,75 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro return ib.get() +def _collect_selected_indices_and_scores_ir( + selected_indices, + selected_scores, + num_detections, + row_offsets, + num_total_detections, + collected_indices, + collected_scores, +): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + selected_scores = ib.buffer_ptr(selected_scores) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + num_total_detections = ib.buffer_ptr(num_total_detections) + collected_indices = ib.buffer_ptr(collected_indices) + collected_scores = ib.buffer_ptr(collected_scores) + + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = ceil_div(num_boxes, nthread_tx) + nthread_by = batch_size * num_class + tx = te.thread_axis("threadIdx.x") + bx = te.thread_axis("blockIdx.x") + by = te.thread_axis("blockIdx.y") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + ib.scope_attr(by, "thread_extent", nthread_by) + zero = cast(0, "int64") + + with ib.new_scope(): + idx = bx * nthread_tx + tx + idy = cast(by, "int64") + batch_id = idy // num_class + class_id = idy % num_class + + with ib.if_scope(idx < num_detections[batch_id, class_id]): + offset = row_offsets[batch_id, class_id] + idx + collected_indices[batch_id, offset, 0] = class_id + collected_indices[batch_id, offset, 1] = cast(selected_indices[idy, idx], "int64") + collected_scores[batch_id, offset] = selected_scores[idy, idx] + with ib.else_scope(): + with ib.if_scope(idx < num_boxes): + offset = ( + num_total_detections[batch_id] + + class_id * num_boxes + - row_offsets[batch_id, class_id] + + idx + - num_detections[batch_id, class_id] + ) + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = 0.0 + + return ib.get() + + def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + 
max_total_size=None, + output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. @@ -1012,6 +1081,8 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early + output_format : str + Returns ------- out : [tvm.te.Tensor, tvm.te.Tensor] @@ -1029,7 +1100,7 @@ def all_class_non_max_suppression( sorted_scores, sorted_indices = _dispatch_sort(scores, ret_type="both") valid_count = _get_valid_box_count(sorted_scores, score_threshold) - selected_indices, num_detections = run_all_class_nms( + selected_indices, selected_scores, num_detections = run_all_class_nms( boxes, sorted_scores, sorted_indices, @@ -1037,14 +1108,32 @@ def all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, _nms_loop, + return_scores=(output_format == "tensorflow"), ) + if output_format == "onnx": + row_offsets, num_total_detections = exclusive_scan( + num_detections, return_reduction=True, output_dtype="int64" + ) + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) + return [selected_indices, num_total_detections] + + num_detections_per_batch = reshape(num_detections, (batch, num_class)) row_offsets, num_total_detections = exclusive_scan( - num_detections, return_reduction=True, output_dtype="int64" + num_detections_per_batch, return_reduction=True, output_dtype="int64", axis=1 ) - selected_indices = collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + selected_indices, selected_scores = collect_selected_indices_and_scores( + selected_indices, + selected_scores, + num_detections_per_batch, + row_offsets, + num_total_detections, + _collect_selected_indices_and_scores_ir, ) - return [selected_indices, num_total_detections] + num_total_detections = minimum(num_total_detections, max_total_size) + + return [selected_indices, selected_scores, num_total_detections] diff --git a/python/tvm/topi/cuda/vision.py b/python/tvm/topi/cuda/vision.py index 88983ab89f76..5208aeccd413 100644 --- a/python/tvm/topi/cuda/vision.py +++ b/python/tvm/topi/cuda/vision.py @@ -39,7 +39,9 @@ def traverse(op): traverse(tensor.op) scheduled_ops.append(op) - traverse(outs[0].op) + for o in outs: + traverse(o.op) + return s diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index b4d0167be2b1..c4743181ab9c 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -490,7 +490,7 @@ def gather(data, axis, indices): return cpp.gather(data, axis, indices) -def gather_nd(a, indices): +def gather_nd(a, indices, batch_dims=0): """Gather elements from a n-dimension array.. Parameters @@ -505,7 +505,7 @@ def gather_nd(a, indices): ------- ret : tvm.te.Tensor """ - return cpp.gather_nd(a, indices) + return cpp.gather_nd(a, indices, batch_dims) def matmul(a, b, transp_a=False, transp_b=False): diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 744c5ef7feda..82ec6ecccafd 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -22,14 +22,16 @@ from tvm.te import hybrid from tvm.tir import if_then_else -from ..sort import sort, argsort +from ..sort import argsort from ..math import cast -from ..transform import reshape +from ..transform import reshape, gather +from ..broadcast import minimum from .. 
import reduction from ..scan import cumsum from .nms_util import ( binary_search, collect_selected_indices, + collect_selected_indices_and_scores, run_all_class_nms, ) @@ -727,8 +729,63 @@ def _collect_selected_indices_ir(num_class, selected_indices, num_detections, ro return ib.get() +def _collect_selected_indices_and_scores_ir( + selected_indices, + selected_scores, + num_detections, + row_offsets, + num_total_detections, + collected_indices, + collected_scores, +): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + + ib = tvm.tir.ir_builder.create() + + selected_indices = ib.buffer_ptr(selected_indices) + selected_scores = ib.buffer_ptr(selected_scores) + num_detections = ib.buffer_ptr(num_detections) + row_offsets = ib.buffer_ptr(row_offsets) + num_total_detections = ib.buffer_ptr(num_total_detections) + collected_indices = ib.buffer_ptr(collected_indices) + collected_scores = ib.buffer_ptr(collected_scores) + zero = cast(0, "int64") + + with ib.for_range(0, batch_size * num_class, name="i", kind="parallel") as i: + i = cast(i, "int64") + batch_id = i // num_class + class_id = i % num_class + + with ib.for_range(0, num_boxes, name="j") as j: + with ib.if_scope(j < num_detections[batch_id, class_id]): + offset = row_offsets[batch_id, class_id] + j + collected_indices[batch_id, offset, 0] = class_id + collected_indices[batch_id, offset, 1] = cast(selected_indices[i, j], "int64") + collected_scores[batch_id, offset] = selected_scores[i, j] + with ib.else_scope(): + offset = ( + num_total_detections[batch_id] + + class_id * num_boxes + - row_offsets[batch_id, class_id] + + j + - num_detections[batch_id, class_id] + ) + collected_indices[batch_id, offset, 0] = zero + collected_indices[batch_id, offset, 1] = zero + collected_scores[batch_id, offset] = 0.0 + + return ib.get() + + def all_class_non_max_suppression( - boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size=None, + output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX NonMaxSuppression and TensorFlow combined_non_max_suppression. 
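The `_collect_selected_indices_and_scores_ir` kernel above packs the per-(batch, class) detections into one contiguous buffer per batch using exclusive-scan row offsets. Below is a NumPy reference of that packing, shown only as an illustration under assumed shapes; it is not code from the patch and it omits the zero-filling of the trailing invalid slots:

    import numpy as np

    def collect_ref(selected_indices, selected_scores, num_detections):
        # selected_indices, selected_scores: (batch * num_class, num_boxes)
        # num_detections: (batch, num_class), detections kept per (batch, class) row
        batch, num_class = num_detections.shape
        num_boxes = selected_indices.shape[1]
        out_indices = np.zeros((batch, num_class * num_boxes, 2), dtype="int64")
        out_scores = np.zeros((batch, num_class * num_boxes), dtype="float32")
        # Exclusive scan over classes gives each class its start offset within the batch row.
        row_offsets = np.cumsum(num_detections, axis=1) - num_detections
        for b in range(batch):
            for c in range(num_class):
                n = int(num_detections[b, c])
                off = int(row_offsets[b, c])
                out_indices[b, off : off + n, 0] = c
                out_indices[b, off : off + n, 1] = selected_indices[b * num_class + c, :n]
                out_scores[b, off : off + n] = selected_scores[b * num_class + c, :n]
        return out_indices, out_scores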
@@ -751,6 +808,8 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early + output_format : TODO + Returns ------- out : [tvm.te.Tensor, tvm.te.Tensor] @@ -765,11 +824,12 @@ def all_class_non_max_suppression( batch, num_class, num_boxes = scores.shape scores = reshape(scores, (batch * num_class, num_boxes)) - sorted_scores = sort(scores, axis=1, is_ascend=False) sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32") + sorted_scores = gather(scores, 1, sorted_indices) + valid_count = _get_valid_box_count(sorted_scores, score_threshold) - selected_indices, num_detections = run_all_class_nms( + selected_indices, selected_scores, num_detections = run_all_class_nms( boxes, sorted_scores, sorted_indices, @@ -777,14 +837,31 @@ def all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, _nms_loop, + return_scores=(output_format == "tensorflow"), ) - row_offsets = cumsum(num_detections, exclusive=True, dtype="int64") - - num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1) + if output_format == "onnx": + row_offsets = cumsum(num_detections, exclusive=True, dtype="int64") + num_total_detections = reduction.sum(cast(num_detections, "int64"), axis=1) - selected_indices = collect_selected_indices( - num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + selected_indices = collect_selected_indices( + num_class, selected_indices, num_detections, row_offsets, _collect_selected_indices_ir + ) + return [selected_indices, num_total_detections] + + num_detections_per_batch = reshape(num_detections, (batch, num_class)) + row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1) + num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1) + + selected_indices, selected_scores = collect_selected_indices_and_scores( + selected_indices, + selected_scores, + num_detections_per_batch, + row_offsets, + num_total_detections, + _collect_selected_indices_and_scores_ir, ) - return [selected_indices, num_total_detections] + num_total_detections = minimum(num_total_detections, max_total_size) + + return [selected_indices, selected_scores, num_total_detections] diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index 1147b1687783..dfa3c0788295 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -106,28 +106,31 @@ def collect_selected_indices(num_class, selected_indices, num_detections, row_of first, in descending of scores, followed by boxes from batch 0, class 1 etc. 
""" batch_class, num_boxes = selected_indices.shape - - selected_indices_buf = tvm.tir.decl_buffer( - selected_indices.shape, selected_indices.dtype, "selected_indices_buf", data_alignment=8 - ) - num_detections_buf = tvm.tir.decl_buffer( - num_detections.shape, num_detections.dtype, "num_detections_buf", data_alignment=8 - ) - row_offsets_buf = tvm.tir.decl_buffer( - row_offsets.shape, row_offsets.dtype, "row_offsets_buf", data_alignment=8 - ) - return te.extern( [(batch_class * num_boxes, 3)], [selected_indices, num_detections, row_offsets], lambda ins, outs: ir(num_class, ins[0], ins[1], ins[2], outs[0]), dtype=["int64"], - in_buffers=[selected_indices_buf, num_detections_buf, row_offsets_buf], name="collect_indices", tag="collect_indices", ) +def collect_selected_indices_and_scores( + selected_indices, selected_scores, num_detections, row_offsets, num_total_detections, ir +): + batch_size, num_class = row_offsets.shape + num_boxes = selected_indices.shape[1] + return te.extern( + [(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)], + [selected_indices, selected_scores, num_detections, row_offsets, num_total_detections], + lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], outs[1]), + dtype=["int64", "float32"], + name="collect_indices_and_scores", + tag="collect_indices_and_scores", + ) + + def _all_class_nms_ir( boxes, sorted_scores, @@ -139,6 +142,7 @@ def _all_class_nms_ir( iou_threshold, max_output_size_per_class, box_indices, + selected_scores, num_valid_boxes, nms_loop, ): @@ -150,6 +154,9 @@ def _all_class_nms_ir( box_indices = ib.buffer_ptr(box_indices) num_valid_boxes = ib.buffer_ptr(num_valid_boxes) + if selected_scores is not None: + selected_scores = ib.buffer_ptr(selected_scores) + if isinstance(iou_threshold, float): iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) @@ -171,6 +178,9 @@ def on_new_valid_box(ib, tid, num_current_valid_box, i, j): with ib.if_scope(tid + 0 == 0): box_indices[i, num_current_valid_box] = sorted_indices[i, j] + if selected_scores is not None: + selected_scores[i, num_current_valid_box] = sorted_scores[i, j] + def on_new_invalidated_box(*_): pass @@ -201,6 +211,7 @@ def run_all_class_nms( max_output_size_per_class, iou_threshold, nms_loop, + return_scores=False, ): """The core all class NMS routine @@ -242,19 +253,33 @@ def run_all_class_nms( batch_class = sorted_scores.shape[0] num_class = batch_class // batch - boxes_buf = tvm.tir.decl_buffer(boxes.shape, boxes.dtype, "boxes_buf", data_alignment=8) - sorted_scores_buf = tvm.tir.decl_buffer( - sorted_scores.shape, sorted_scores.dtype, "sorted_scores_buf", data_alignment=8 - ) - sorted_indices_buf = tvm.tir.decl_buffer( - sorted_indices.shape, sorted_indices.dtype, "sorted_indices_buf", data_alignment=8 - ) - valid_count_buf = tvm.tir.decl_buffer( - valid_count.shape, "int32", "valid_count_buf", data_alignment=4 - ) + if return_scores is False: + selected_indices, num_detections = te.extern( + [(batch_class, num_boxes), (1, batch_class)], + [boxes, sorted_scores, sorted_indices, valid_count], + lambda ins, outs: _all_class_nms_ir( + ins[0], # boxes + ins[1], # sorted_scores + ins[2], # sorted_indices + ins[3], # valid_count + batch_class, + num_class, + num_boxes, + iou_threshold, + max_output_size_per_class, + outs[0], # box_indices + None, # scores + outs[1], # num_selected_boxes + nms_loop, + ), + dtype=["int32", "int32"], + name="all_class_nms", + tag="all_class_nms", + ) + return selected_indices, None, num_detections return te.extern( - 
[(batch_class, num_boxes), (1, batch_class)], + [(batch_class, num_boxes), (batch_class, num_boxes), (1, batch_class)], [boxes, sorted_scores, sorted_indices, valid_count], lambda ins, outs: _all_class_nms_ir( ins[0], # boxes @@ -267,16 +292,11 @@ def run_all_class_nms( iou_threshold, max_output_size_per_class, outs[0], # box_indices - outs[1], # num_selected_boxes + outs[1], # selected scores + outs[2], # num_selected_boxes nms_loop, ), - dtype=["int32", "int32"], - in_buffers=[ - boxes_buf, - sorted_scores_buf, - sorted_indices_buf, - valid_count_buf, - ], + dtype=["int32", "float32", "int32"], name="all_class_nms", tag="all_class_nms", ) diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 53cd71745d5b..718c9c0a3857 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -152,24 +152,43 @@ bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs IndexExpr num_classes = scores_shape[1]; IndexExpr num_boxes = boxes_shape[1]; - IndexExpr num_total_boxes = Any(); - if (!batch.as() && !num_boxes.as()) { - num_total_boxes = batch * num_classes * num_boxes; - } + const auto* param = attrs.as(); + CHECK(param); - // assign output type std::vector fields; - std::vector oshape{num_total_boxes, 3}; - fields.push_back(TensorType(oshape, DataType::Int(64))); - std::vector countshape{1}; - fields.push_back(TensorType(countshape, DataType::Int(64))); + if (param->output_format == "onnx") { + IndexExpr num_total_boxes = Any(); + if (!batch.as() && !num_boxes.as()) { + num_total_boxes = batch * num_classes * num_boxes; + } + std::vector oshape{num_total_boxes, 3}; + std::vector counts_shape{1}; + fields.push_back(TensorType(oshape, DataType::Int(64))); + fields.push_back(TensorType(counts_shape, DataType::Int(64))); + } else { + ICHECK(param->max_total_size) << "max_total_size required for tf mode"; + Integer max_total_size = param->max_total_size.value(); + IndexExpr num_total_boxes_per_batch = Any(); + if (!num_boxes.as()) { + num_total_boxes_per_batch = num_classes * num_boxes; + } + std::vector indices_shape{batch, num_total_boxes_per_batch, 2}; + std::vector scores_shape{batch, num_total_boxes_per_batch}; + std::vector counts_shape{batch}; + fields.push_back(TensorType(indices_shape, DataType::Int(64))); + fields.push_back(TensorType(scores_shape, DataType::Float(32))); + fields.push_back(TensorType(counts_shape, DataType::Int(64))); + } reporter->Assign(types[5], TupleType(Array(fields))); return true; } Expr MakeAllClassNMS(Expr boxes, Expr scores, Expr max_output_boxes_per_class, Expr iou_threshold, - Expr score_threshold) { + Expr score_threshold, Optional max_total_size = NullValue(), + std::string output_format = "onnx") { auto attrs = make_object(); + attrs->max_total_size = std::move(max_total_size); + attrs->output_format = std::move(output_format); static const Op& op = Op::Get("vision.all_class_non_max_suppression"); return Call(op, {boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold}, Attrs(attrs), {}); diff --git a/src/topi/transform.cc b/src/topi/transform.cc index db54d5a99a91..58ba16871b64 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -131,7 +131,7 @@ TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = gather_nd(args[0], args[1]); + *rv = gather_nd(args[0], args[1], args[2]); }); TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* 
rv) { diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index f29450dbb604..ff000a0aa9b7 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -151,7 +151,6 @@ def run_tvm_graph( return vmobj_to_list(result) elif mode == "vm": with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): - print(mod["main"]) mod = relay.transform.InferType()(mod) vm_exec = relay.vm.compile(mod, target="llvm", params=params) if serialize: From 7ac92dbfd555851d791a0e4cb702d0b8f501fc46 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Jun 2021 17:53:03 +0900 Subject: [PATCH 02/17] make combined nms converter public --- python/tvm/relay/frontend/tensorflow.py | 128 ++++++++++++------------ 1 file changed, 62 insertions(+), 66 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4a0f0670baec..6e99a0121638 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -793,86 +793,82 @@ def _impl(inputs, attr, params, mod): return _impl -def _combined_nms(): - def all_class_impl( - batch_size, - max_output_boxes_per_batch, +def convert_combined_nms_with_all_class( + batch_size, + max_output_boxes_per_batch, + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + max_total_size, + clip_boxes, +): + (selected_indices, selected_scores, num_detections,) = _op.vision.all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, max_total_size, - clip_boxes, - ): - ( - selected_indices, - selected_scores, - num_detections, - ) = _op.vision.all_class_non_max_suppression( - boxes, - scores, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - max_total_size, - output_format="tensorflow", - ) - box_range = _op.arange( - _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64" - ) - tile_batch_reps = ( - _op.concatenate([batch_size, 1]) - if isinstance(batch_size, tvm.tir.Any) - else _op.const([batch_size, 1]) - ) - box_range_2d = _op.tile(box_range, tile_batch_reps) - valid_mask = _op.cast( - _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32" - ) + output_format="tensorflow", + ) + box_range = _op.arange( + _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64" + ) + tile_batch_reps = ( + _op.concatenate([batch_size, 1]) + if isinstance(batch_size, tvm.tir.Any) + else _op.const([batch_size, 1]) + ) + box_range_2d = _op.tile(box_range, tile_batch_reps) + valid_mask = _op.cast( + _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32" + ) - def select_topk(do_zero_pad): - def true_branch(): - arange = _op.arange( - _op.const(0, dtype="int64"), - _op.const(max_output_boxes_per_batch, dtype="int64"), - dtype="int64", - ) - pad = _op.full( - _op.const(0, dtype="int64"), (max_total_size - max_output_boxes_per_batch,) - ) - topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps) - nmsed_scores = _op.gather(selected_scores, 1, topk_indices) - nmsed_scores = nmsed_scores * valid_mask - return nmsed_scores, topk_indices + def select_topk(do_zero_pad): + def true_branch(): + arange = _op.arange( + _op.const(0, dtype="int64"), + _op.const(max_output_boxes_per_batch, dtype="int64"), + dtype="int64", + ) + pad = _op.full( + _op.const(0, dtype="int64"), (max_total_size - 
max_output_boxes_per_batch,) + ) + topk_indices = _op.tile(_op.concatenate([arange, pad], 0), tile_batch_reps) + nmsed_scores = _op.gather(selected_scores, 1, topk_indices) + nmsed_scores = nmsed_scores * valid_mask + return nmsed_scores, topk_indices - def false_branch(): - return _op.topk(selected_scores, k=max_total_size, axis=1, ret_type="both") + def false_branch(): + return _op.topk(selected_scores, k=max_total_size, axis=1, ret_type="both") - # TODO(masahi): support dynamic num_boxes - # return _expr.If(do_zero_pad, true_branch(), false_branch()) - return true_branch() if do_zero_pad else false_branch() + # TODO(masahi): support dynamic num_boxes + # return _expr.If(do_zero_pad, true_branch(), false_branch()) + return true_branch() if do_zero_pad else false_branch() - assert isinstance( - max_output_boxes_per_batch, int - ), "dynamic number of boxes not supported yet." - nmsed_scores, topk_indices = select_topk(max_output_boxes_per_batch < max_total_size) + assert isinstance(max_output_boxes_per_batch, int), "dynamic number of boxes not supported yet." + nmsed_scores, topk_indices = select_topk(max_output_boxes_per_batch < max_total_size) - indices = _op.take(selected_indices, topk_indices, axis=1, batch_dims=1) - nmsed_box_indices = _op.take(indices, _op.const(1), axis=2) - nmsed_classes = _op.take(indices, _op.const(0), axis=2) - nmsed_boxes = _op.take(boxes, nmsed_box_indices, axis=1, batch_dims=1) + indices = _op.take(selected_indices, topk_indices, axis=1, batch_dims=1) + nmsed_box_indices = _op.take(indices, _op.const(1), axis=2) + nmsed_classes = _op.take(indices, _op.const(0), axis=2) + nmsed_classes = _op.cast(nmsed_classes, "float32") + nmsed_boxes = _op.take(boxes, nmsed_box_indices, axis=1, batch_dims=1) - if clip_boxes: - nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) - nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) + if clip_boxes: + nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) + nmsed_boxes = _op.minimum(nmsed_boxes, _expr.const(1, dtype="float32")) - nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2) + nmsed_boxes = nmsed_boxes * _op.expand_dims(valid_mask, axis=2) - return _expr.TupleWrapper( - _expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4 - ) + return _expr.TupleWrapper( + _expr.Tuple([nmsed_boxes, nmsed_scores, nmsed_classes, num_detections]), 4 + ) + +def _combined_nms(): def _impl(inputs, attr, params, mod): # Get parameter values boxes = inputs[0] @@ -904,7 +900,7 @@ def _impl(inputs, attr, params, mod): boxes = _op.squeeze(boxes, axis=[2]) scores_trans = _op.transpose(scores, [0, 2, 1]) max_output_boxes_per_batch = num_anchors * num_classes - return all_class_impl( + return convert_combined_nms_with_all_class( batch_size, max_output_boxes_per_batch, boxes, From 090616ca0e75351e1b062823e7c0b1ac8625077c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Jun 2021 17:54:28 +0900 Subject: [PATCH 03/17] do topk on smaller score tensor --- python/tvm/relay/frontend/tensorflow.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 6e99a0121638..ea9168072d3c 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -796,6 +796,7 @@ def _impl(inputs, attr, params, mod): def convert_combined_nms_with_all_class( batch_size, max_output_boxes_per_batch, + num_class, boxes, scores, 
max_output_boxes_per_class, @@ -842,7 +843,11 @@ def true_branch(): return nmsed_scores, topk_indices def false_branch(): - return _op.topk(selected_scores, k=max_total_size, axis=1, ret_type="both") + slice_mx = _op.const([max_output_boxes_per_class * num_class], dtype="int64") + selected_scores_slice = _op.strided_slice( + selected_scores, begin=_op.const([0], dtype="int64"), end=slice_mx, axes=[1] + ) + return _op.topk(selected_scores_slice, k=max_total_size, axis=1, ret_type="both") # TODO(masahi): support dynamic num_boxes # return _expr.If(do_zero_pad, true_branch(), false_branch()) @@ -903,6 +908,7 @@ def _impl(inputs, attr, params, mod): return convert_combined_nms_with_all_class( batch_size, max_output_boxes_per_batch, + num_classes, boxes, scores_trans, max_output_size, From fd49349573e330672b06e631a08f25d053a311dd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Jun 2021 17:59:07 +0900 Subject: [PATCH 04/17] update tests --- tests/python/frontend/tensorflow/test_forward.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index ff000a0aa9b7..331553388b48 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -3437,16 +3437,18 @@ def _test_forward_combined_nms( "nms/CombinedNonMaxSuppression:2", "nms/CombinedNonMaxSuppression:3", ], - mode="vm", ) def test_forward_combined_nms(): """CombinedNonMaxSuppression""" _test_forward_combined_nms((1, 64, 1, 4), (1, 64, 1), 0.7, 0.5, 64, 64) + _test_forward_combined_nms((1, 32, 1, 4), (1, 32, 1), 0.7, 0.5, 10, 64) + _test_forward_combined_nms((1, 32, 1, 4), (1, 32, 2), 0.7, 0.5, 32, 64) _test_forward_combined_nms((1, 64, 1, 4), (1, 64, 20), 0.7, 0.5, 64, 10) _test_forward_combined_nms((1, 64, 20, 4), (1, 64, 20), 0.7, 0.5, 64, 64, clip_boxes=True) _test_forward_combined_nms((2, 200, 1, 4), (2, 200, 1), 0.4, 0.6, 100, 100) + _test_forward_combined_nms((2, 200, 1, 4), (2, 200, 10), 0.4, 0.2, 150, 1000) ####################################################################### From 7716467a475649f1f4d315169dbe67a4b1631588 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Jun 2021 18:07:54 +0900 Subject: [PATCH 05/17] remove max_total_size attribute, do minimum in relay side --- include/tvm/relay/attrs/vision.h | 10 ++++++---- python/tvm/relay/frontend/tensorflow.py | 2 +- python/tvm/relay/op/vision/nms.py | 2 -- python/tvm/topi/cuda/nms.py | 3 --- python/tvm/topi/vision/nms.py | 3 --- src/relay/op/vision/nms.cc | 6 +----- 6 files changed, 8 insertions(+), 18 deletions(-) diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 3a61f18eb36e..0b3a85662e99 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -114,16 +114,18 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { - Optional max_total_size; std::string output_format; TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs, "relay.attrs.AllClassNonMaximumSuppressionAttrs") { - TVM_ATTR_FIELD(max_total_size).set_default(NullValue()).describe("TODO"); - TVM_ATTR_FIELD(output_format).set_default("onnx").describe("Output format. onnx or tensorflow"); + TVM_ATTR_FIELD(output_format) + .set_default("onnx") + .describe( + "Output format, onnx or tensorflow. 
Returns outputs so that they can be easily " + "consumed by each frontend."); } }; diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index ea9168072d3c..8e13fe015f0e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -811,7 +811,6 @@ def convert_combined_nms_with_all_class( max_output_boxes_per_class, iou_threshold, score_threshold, - max_total_size, output_format="tensorflow", ) box_range = _op.arange( @@ -861,6 +860,7 @@ def false_branch(): nmsed_classes = _op.take(indices, _op.const(0), axis=2) nmsed_classes = _op.cast(nmsed_classes, "float32") nmsed_boxes = _op.take(boxes, nmsed_box_indices, axis=1, batch_dims=1) + num_detections = _op.minimum(num_detections, _op.const(max_total_size, dtype="int64")) if clip_boxes: nmsed_boxes = _op.maximum(nmsed_boxes, _expr.const(0, dtype="float32")) diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index dc1a08e4b6ac..e5aaf71b6dcc 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -157,7 +157,6 @@ def all_class_non_max_suppression( max_output_boxes_per_class=-1, iou_threshold=-1.0, score_threshold=-1.0, - max_total_size=None, output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX @@ -206,7 +205,6 @@ def all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, score_threshold, - max_total_size, output_format, ) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 2ca247df413f..1bdb69154394 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -1057,7 +1057,6 @@ def all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, score_threshold, - max_total_size=None, output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX @@ -1134,6 +1133,4 @@ def all_class_non_max_suppression( _collect_selected_indices_and_scores_ir, ) - num_total_detections = minimum(num_total_detections, max_total_size) - return [selected_indices, selected_scores, num_total_detections] diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 82ec6ecccafd..7b7df2e4d484 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -784,7 +784,6 @@ def all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, score_threshold, - max_total_size=None, output_format="onnx", ): """Non-maximum suppression operator for object detection, corresponding to ONNX @@ -862,6 +861,4 @@ def all_class_non_max_suppression( _collect_selected_indices_and_scores_ir, ) - num_total_detections = minimum(num_total_detections, max_total_size) - return [selected_indices, selected_scores, num_total_detections] diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 718c9c0a3857..8c33c1648cf3 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -166,8 +166,6 @@ bool AllClassNMSRel(const Array& types, int num_inputs, const Attrs& attrs fields.push_back(TensorType(oshape, DataType::Int(64))); fields.push_back(TensorType(counts_shape, DataType::Int(64))); } else { - ICHECK(param->max_total_size) << "max_total_size required for tf mode"; - Integer max_total_size = param->max_total_size.value(); IndexExpr num_total_boxes_per_batch = Any(); if (!num_boxes.as()) { num_total_boxes_per_batch = num_classes * num_boxes; @@ -184,10 +182,8 @@ bool AllClassNMSRel(const Array& types, int 
num_inputs, const Attrs& attrs } Expr MakeAllClassNMS(Expr boxes, Expr scores, Expr max_output_boxes_per_class, Expr iou_threshold, - Expr score_threshold, Optional max_total_size = NullValue(), - std::string output_format = "onnx") { + Expr score_threshold, std::string output_format = "onnx") { auto attrs = make_object(); - attrs->max_total_size = std::move(max_total_size); attrs->output_format = std::move(output_format); static const Op& op = Op::Get("vision.all_class_non_max_suppression"); return Call(op, {boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold}, From 839963faaef389f28901bdd2c4f105841a467a35 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Jun 2021 18:23:59 +0900 Subject: [PATCH 06/17] fix topk --- python/tvm/relay/frontend/tensorflow.py | 12 ++++++++---- python/tvm/relay/op/strategy/generic.py | 2 -- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 8e13fe015f0e..a4f7d4da96ef 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -842,10 +842,14 @@ def true_branch(): return nmsed_scores, topk_indices def false_branch(): - slice_mx = _op.const([max_output_boxes_per_class * num_class], dtype="int64") - selected_scores_slice = _op.strided_slice( - selected_scores, begin=_op.const([0], dtype="int64"), end=slice_mx, axes=[1] - ) + if isinstance(max_output_boxes_per_class, int): + # Do topk on smaller input if possible + slice_mx = _op.const([max_output_boxes_per_class * num_class], dtype="int64") + selected_scores_slice = _op.strided_slice( + selected_scores, begin=_op.const([0], dtype="int64"), end=slice_mx, axes=[1] + ) + else: + selected_scores_slice = selected_scores return _op.topk(selected_scores_slice, k=max_total_size, axis=1, ret_type="both") # TODO(masahi): support dynamic num_boxes diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 451d01a4fc05..d56820e409aa 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1095,7 +1095,6 @@ def _compute_nms(attrs, inputs, out_type): max_output_size = inputs[2] iou_threshold = inputs[3] score_threshold = inputs[4] - max_total_size = attrs.max_total_size output_format = attrs.output_format return topi_compute( inputs[0], @@ -1103,7 +1102,6 @@ def _compute_nms(attrs, inputs, out_type): max_output_size, iou_threshold, score_threshold, - max_total_size, output_format, ) From 16f76538cc57d2e5b3b352e4f9ffe0457eb6baf1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Jun 2021 18:50:35 +0900 Subject: [PATCH 07/17] update relay doc --- python/tvm/relay/op/vision/nms.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index e5aaf71b6dcc..1cf18e554d49 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -180,17 +180,31 @@ def all_class_non_max_suppression( score_threshold : float or relay.Expr, optional Score threshold to filter out low score boxes early + output_format : string, optional + "onnx" or "tensorflow". Specify by which frontends the outputs are + intented to be consumed. 
+ Returns ------- out : relay.Tuple - The output is a relay.Tuple of two tensors, the first is `indices` of size - `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor - `num_total_detection` of shape `(1,)` representing the total number of selected boxes. + If `output_format` is 'onnx', the output is a relay.Tuple of two tensors, the first is + `indices` of size `(batch_size * num_class* num_boxes , 3)` and the second is a scalar + tensor `num_total_detection` of shape `(1,)` representing the total number of selected + boxes. The three values in `indices` encode batch, class, and box indices. Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` rows are valid. - TODO(trvmorr): explain tf mode + + If `output_format` is 'tensorflow', the output is a relay.Tuple of three tensors, the first + is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of + size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size + `(batch_size,)` representing the total number of selected boxes per batch. The two values + in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at + batch b, only the first `num_total_detection[b]` entries are valid. The second axis of + `indices` and `scores` are sorted within each class by box scores, but not across classes. + So the box indices and scores for the class 0 come first in a sorted order, followed by + the class 1 etc. """ if not isinstance(max_output_boxes_per_class, expr.Expr): max_output_boxes_per_class = expr.const(max_output_boxes_per_class, "int32") From 3299e2bc10c42bd1440b9068dca748c38b8542a1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Jun 2021 19:13:20 +0900 Subject: [PATCH 08/17] update doc --- include/tvm/relay/attrs/vision.h | 2 +- python/tvm/relay/op/vision/nms.py | 4 +-- python/tvm/topi/cuda/nms.py | 20 +++++++++++--- python/tvm/topi/vision/nms.py | 20 +++++++++++--- python/tvm/topi/vision/nms_util.py | 44 +++++++++++++++++++++++++++--- 5 files changed, 75 insertions(+), 15 deletions(-) diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 0b3a85662e99..976304e79c34 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -124,7 +124,7 @@ struct AllClassNonMaximumSuppressionAttrs TVM_ATTR_FIELD(output_format) .set_default("onnx") .describe( - "Output format, onnx or tensorflow. Returns outputs so that they can be easily " + "Output format, onnx or tensorflow. Returns outputs in a way that can be easily " "consumed by each frontend."); } }; diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 1cf18e554d49..8c54075d952c 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -187,7 +187,7 @@ def all_class_non_max_suppression( Returns ------- out : relay.Tuple - If `output_format` is 'onnx', the output is a relay.Tuple of two tensors, the first is + If `output_format` is "onnx", the output is a relay.Tuple of two tensors, the first is `indices` of size `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor `num_total_detection` of shape `(1,)` representing the total number of selected boxes. The three values in `indices` encode batch, class, and box indices. 
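To make the two output layouts described in this docstring concrete, here is a minimal Relay usage sketch. The shapes, thresholds, and variable names are invented for illustration, and the "tensorflow" format is only available once this patch series is applied.

    from tvm import relay

    # toy shapes: batch_size=1, num_boxes=5, num_class=2
    boxes = relay.var("boxes", shape=(1, 5, 4), dtype="float32")
    scores = relay.var("scores", shape=(1, 2, 5), dtype="float32")

    # "onnx" format: a 2-tuple (indices, num_total_detection)
    onnx_out = relay.vision.all_class_non_max_suppression(
        boxes, scores, max_output_boxes_per_class=3,
        iou_threshold=0.5, score_threshold=0.2, output_format="onnx",
    )
    indices, count = onnx_out[0], onnx_out[1]

    # "tensorflow" format: a 3-tuple (indices, scores, per-batch num_total_detection)
    tf_out = relay.vision.all_class_non_max_suppression(
        boxes, scores, max_output_boxes_per_class=3,
        iou_threshold=0.5, score_threshold=0.2, output_format="tensorflow",
    )
    tf_indices, tf_scores, tf_count = tf_out[0], tf_out[1], tf_out[2]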
@@ -196,7 +196,7 @@ def all_class_non_max_suppression( `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` rows are valid. - If `output_format` is 'tensorflow', the output is a relay.Tuple of three tensors, the first + If `output_format` is "tensorflow", the output is a relay.Tuple of three tensors, the first is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size `(batch_size,)` representing the total number of selected boxes per batch. The two values diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 1bdb69154394..979fdc27d65f 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -1080,18 +1080,30 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early - output_format : str + output_format : str, optional + "onnx" or "tensorflow", see below Returns ------- - out : [tvm.te.Tensor, tvm.te.Tensor] - The output is two tensors, the first is `indices` of size + out : list of tvm.te.Tensor + If `output_format` is "onnx", the output is two tensors. The first is `indices` of size `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor `num_total_detection` of shape `(1,)` representing the total number of selected - boxes. Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + boxes. The three values in `indices` encode batch, class, and box indices. + Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` rows are valid. + + If `output_format` is "tensorflow", the output is three tensors, the first + is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of + size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size + `(batch_size,)` representing the total number of selected boxes per batch. The two values + in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at + batch b, only the first `num_total_detection[b]` entries are valid. The second axis of + `indices` and `scores` are sorted within each class by box scores, but not across classes. + So the box indices and scores for the class 0 come first in a sorted order, followed by + the class 1 etc. """ batch, num_class, num_boxes = scores.shape diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 7b7df2e4d484..40755d810fd8 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -807,18 +807,30 @@ def all_class_non_max_suppression( score_threshold : float or tvm.te.Tensor, optional Score threshold to filter out low score boxes early - output_format : TODO + output_format : str, optional + "onnx" or "tensorflow", see below. Returns ------- - out : [tvm.te.Tensor, tvm.te.Tensor] - The output is two tensors, the first is `indices` of size + out : list of tvm.te.Tensor + If `output_format` is "onnx", the output is two tensors. The first is `indices` of size `(batch_size * num_class* num_boxes , 3)` and the second is a scalar tensor `num_total_detection` of shape `(1,)` representing the total number of selected - boxes. 
Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come + boxes. The three values in `indices` encode batch, class, and box indices. + Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come first, in descending of scores, followed by boxes from batch 0, class 1 etc. Out of `batch_size * num_class* num_boxes` rows of indices, only the first `num_total_detection` rows are valid. + + If `output_format` is "tensorflow", the output is three tensors, the first + is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the second is `scores` of + size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size + `(batch_size,)` representing the total number of selected boxes per batch. The two values + in `indices` encode class and box indices. Of num_class * num_boxes boxes in `indices` at + batch b, only the first `num_total_detection[b]` entries are valid. The second axis of + `indices` and `scores` are sorted within each class by box scores, but not across classes. + So the box indices and scores for the class 0 come first in a sorted order, followed by + the class 1 etc. """ batch, num_class, num_boxes = scores.shape scores = reshape(scores, (batch * num_class, num_boxes)) diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index dfa3c0788295..d12592fd111a 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -119,6 +119,38 @@ def collect_selected_indices(num_class, selected_indices, num_detections, row_of def collect_selected_indices_and_scores( selected_indices, selected_scores, num_detections, row_offsets, num_total_detections, ir ): + """Collect selected indices and scores from the core NMS loop into one linear output + + Parameters + ---------- + num_class : int + + selected_indices: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices + of boxes selected by the core NMS loop. + + selected_scores: tvm.te.Tensor + 2-D tensor with shape (batch_size * num_classes, num_boxes), representing the scores + of boxes selected by the core NMS loop. + + num_detections : tvm.te.Tensor + 2-D tensor with shape (batch_size, num_classes), representing + the number of boxes selected by the core NMS loop, per batch and class + + row_offsets : tvm.te.Tensor + 2-D tensor with shape (batch_size, num_classes), this should be the exclusive scan + of num_detections along axis 1 + + ir : function + A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py + + Returns + ------- + out : [tvm.te.Tensor, tvm.te.Tensor] + The output is two tensors. The first is indices of size + (batch_size, num_class * num_boxes, 2), and the second is scores of size + (batch_size, num_class * num_boxes). + """ batch_size, num_class = row_offsets.shape num_boxes = selected_indices.shape[1] return te.extern( @@ -241,13 +273,17 @@ def run_all_class_nms( nms_loop : function A core NMS loop, see its usage in vision/nms.py and cuda/nms.py + return_scores : bool, optional + Whether or not to return selected scores, needed by the tensorflow output format.
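As a reading aid for collect_selected_indices_and_scores documented above, here is a rough NumPy reference of the collection it performs. This is only an illustrative sketch; the real work is done by the `ir` callback in vision/nms.py and cuda/nms.py, and the helper name below is made up.

    import numpy as np

    def collect_selected_indices_and_scores_ref(selected_indices, selected_scores, num_detections):
        # selected_indices, selected_scores: (batch_size * num_class, num_boxes)
        # num_detections: (batch_size, num_class)
        batch_size, num_class = num_detections.shape
        num_boxes = selected_indices.shape[1]
        # exclusive scan of num_detections along the class axis
        row_offsets = np.cumsum(num_detections, axis=1) - num_detections
        out_indices = np.zeros((batch_size, num_class * num_boxes, 2), dtype="int64")
        out_scores = np.zeros((batch_size, num_class * num_boxes), dtype="float32")
        for b in range(batch_size):
            for c in range(num_class):
                n = num_detections[b, c]
                off = row_offsets[b, c]
                # each collected row stores (class index, box index); scores share the layout
                out_indices[b, off:off + n, 0] = c
                out_indices[b, off:off + n, 1] = selected_indices[b * num_class + c, :n]
                out_scores[b, off:off + n] = selected_scores[b * num_class + c, :n]
        return out_indices, out_scores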
+ Returns ------- - out : [tvm.te.Tensor, tvm.te.Tensor] - The output is two tensors, the first is indices of size - (batch_size * num_class, num_boxes) and the second is a tensor + out : a list of tvm.te.Tensor + The output is three tensors, the first and second are indices and scores of size + (batch_size * num_class, num_boxes), and the third is a tensor num_selected_boxes of shape (batch_size * num_class,) representing the total number of - selected boxes per batch and class. + selected boxes per batch and class. If return_scores is False, the second output is + None. """ batch, num_boxes, _ = boxes.shape batch_class = sorted_scores.shape[0] From ef919ae181aa92ed3b38a24798efa84334b6ff4b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Jun 2021 19:26:20 +0900 Subject: [PATCH 09/17] fix pylint --- python/tvm/relay/frontend/tensorflow.py | 8 +++++--- python/tvm/topi/cuda/nms.py | 1 - python/tvm/topi/vision/nms.py | 1 - 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index a4f7d4da96ef..9ad1a1d83033 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -805,6 +805,7 @@ def convert_combined_nms_with_all_class( max_total_size, clip_boxes, ): + """Converts TF combined_nms using Relay all_class_max_suppression op""" (selected_indices, selected_scores, num_detections,) = _op.vision.all_class_non_max_suppression( boxes, scores, @@ -817,7 +818,7 @@ def convert_combined_nms_with_all_class( _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64" ) tile_batch_reps = ( - _op.concatenate([batch_size, 1]) + _op.concatenate([batch_size, 1], axis=0) if isinstance(batch_size, tvm.tir.Any) else _op.const([batch_size, 1]) ) @@ -844,9 +845,10 @@ def true_branch(): def false_branch(): if isinstance(max_output_boxes_per_class, int): # Do topk on smaller input if possible - slice_mx = _op.const([max_output_boxes_per_class * num_class], dtype="int64") + # TODO(masahi): use axes argument in strided slice when it becomes available + slice_mx = _op.const([-1, max_output_boxes_per_class * num_class], dtype="int64") selected_scores_slice = _op.strided_slice( - selected_scores, begin=_op.const([0], dtype="int64"), end=slice_mx, axes=[1] + selected_scores, begin=_op.const([0, 0], dtype="int64"), end=slice_mx ) else: selected_scores_slice = selected_scores diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 979fdc27d65f..e402c5888978 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -24,7 +24,6 @@ from tvm.ir import register_intrin_lowering from tvm.tir import if_then_else from .sort import argsort, argsort_thrust -from ..broadcast import minimum from .scan import exclusive_scan from ..utils import ceil_div from ..math import cast diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 40755d810fd8..7a51946a279a 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -25,7 +25,6 @@ from ..sort import argsort from ..math import cast from ..transform import reshape, gather -from ..broadcast import minimum from .. 
import reduction from ..scan import cumsum from .nms_util import ( From 6a4a0040a71eb93b5c8fd8b04d65008c09ed1875 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Jun 2021 19:44:34 +0900 Subject: [PATCH 10/17] update shape func for tf mode and add test --- python/tvm/relay/op/vision/_vision.py | 22 ++++++++- tests/python/relay/test_any.py | 66 ++++++++++++++++++++------- 2 files changed, 70 insertions(+), 18 deletions(-) diff --git a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index 8d6abf1a8c20..cab9f703e88a 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -89,7 +89,7 @@ def nms_shape_func(attrs, inputs, _): @script -def _all_class_nms_shape_func(boxes_shape, scores_shape): +def _all_class_nms_shape_func_onnx(boxes_shape, scores_shape): out_shape = output_tensor((2,), "int64") count_shape = output_tensor((1,), "int64") @@ -99,9 +99,27 @@ def _all_class_nms_shape_func(boxes_shape, scores_shape): return out_shape, count_shape +@script +def _all_class_nms_shape_func_tf(boxes_shape, scores_shape): + out_indices_shape = output_tensor((3,), "int64") + out_scores_shape = output_tensor((2,), "int64") + count_shape = output_tensor((1,), "int64") + + out_indices_shape[0] = boxes_shape[0] + out_indices_shape[1] = scores_shape[1] * boxes_shape[1] + out_indices_shape[2] = int64(2) + out_scores_shape[0] = boxes_shape[0] + out_scores_shape[1] = scores_shape[1] * boxes_shape[1] + count_shape[0] = boxes_shape[0] + + return out_indices_shape, out_scores_shape, count_shape + + @reg.register_shape_func("vision.all_class_non_max_suppression", False) def all_class_nms_shape_func(attrs, inputs, _): - return _all_class_nms_shape_func(inputs[0], inputs[1]) + if attrs.output_format == "onnx": + return _all_class_nms_shape_func_onnx(inputs[0], inputs[1]) + return _all_class_nms_shape_func_tf(inputs[0], inputs[1]) @script diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 74b8ec51e1fa..57f07b3f00e5 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1611,7 +1611,8 @@ def verify_all_class_non_max_suppression( max_output_boxes_per_class, iou_threshold, score_threshold, - expected_indices, + expected, + output_format="onnx", ): batch_size = boxes_np.shape[0] num_classes = scores_np.shape[1] @@ -1622,23 +1623,23 @@ def verify_all_class_non_max_suppression( ) nms_out = relay.vision.all_class_non_max_suppression( - boxes, - scores, - max_output_boxes_per_class, - iou_threshold, - score_threshold, + boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format ) - three = relay.const(np.array([3]), dtype="int64") - begin = relay.const(np.array([0, 0]), dtype="int64") - end = relay.op.concatenate([nms_out[1], three], axis=0) - strides = relay.const(np.array([1, 1]), dtype="int64") - out = relay.op.strided_slice(nms_out[0], begin, end, strides) - - mod = tvm.IRModule() - mod["main"] = relay.Function([boxes, scores], out) - - check_result([boxes_np, scores_np], mod, [expected_indices]) + if output_format == "onnx": + three = relay.const(np.array([3]), dtype="int64") + begin = relay.const(np.array([0, 0]), dtype="int64") + end = relay.op.concatenate([nms_out[1], three], axis=0) + strides = relay.const(np.array([1, 1]), dtype="int64") + out = relay.op.strided_slice(nms_out[0], begin, end, strides) + mod = tvm.IRModule() + mod["main"] = relay.Function([boxes, scores], out) + check_result([boxes_np, scores_np], mod, [expected]) + else: + out = 
nms_out.tuple_value + mod = tvm.IRModule() + mod["main"] = relay.Function([boxes, scores], out) + check_result([boxes_np, scores_np], mod, expected) boxes = np.array( [ @@ -1668,6 +1669,39 @@ def verify_all_class_non_max_suppression( boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, expected ) + expected = [ + np.array( + [[[0, 4], [0, 2], [1, 4], [1, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]] + ), + np.array( + [ + [ + 0.9, + 0.6, + 0.9, + 0.8, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + ] + ), + np.array([4]), + ] + + verify_all_class_non_max_suppression( + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + expected, + output_format="tensorflow", + ) + boxes = np.array( [ [ From 90b3156fa582b43aa83a87bf58ff1ec5ce33dd63 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Jun 2021 20:50:16 +0900 Subject: [PATCH 11/17] name change --- python/tvm/relay/frontend/tensorflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 9ad1a1d83033..d45d84506f89 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -793,7 +793,7 @@ def _impl(inputs, attr, params, mod): return _impl -def convert_combined_nms_with_all_class( +def convert_combined_nms_with_all_class_nms( batch_size, max_output_boxes_per_batch, num_class, @@ -911,7 +911,7 @@ def _impl(inputs, attr, params, mod): boxes = _op.squeeze(boxes, axis=[2]) scores_trans = _op.transpose(scores, [0, 2, 1]) max_output_boxes_per_batch = num_anchors * num_classes - return convert_combined_nms_with_all_class( + return convert_combined_nms_with_all_class_nms( batch_size, max_output_boxes_per_batch, num_classes, From a2595bc71731f2a6a55911fc74c13be2f24bee81 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 3 Jun 2021 05:38:56 +0900 Subject: [PATCH 12/17] reject dynamic inputs --- python/tvm/relay/frontend/tensorflow.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d45d84506f89..4c39a5a4dfec 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -907,7 +907,11 @@ def _impl(inputs, attr, params, mod): q = boxes_shape[2] num_classes = scores_shape[2] - if q == 1 and isinstance(num_anchors, int): + assert isinstance(batch_size, int) and isinstance( + num_anchors, int + ), "Dynamic inputs not supported yet" + + if q == 1: boxes = _op.squeeze(boxes, axis=[2]) scores_trans = _op.transpose(scores, [0, 2, 1]) max_output_boxes_per_batch = num_anchors * num_classes From e88946688d1b8ff5d6946bdf767af3a2ac17ab9a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 3 Jun 2021 05:40:53 +0900 Subject: [PATCH 13/17] revert gather_nd change --- python/tvm/topi/transform.py | 4 ++-- src/topi/transform.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index c4743181ab9c..b4d0167be2b1 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -490,7 +490,7 @@ def gather(data, axis, indices): return cpp.gather(data, axis, indices) -def gather_nd(a, indices, batch_dims=0): +def gather_nd(a, indices): """Gather elements from a n-dimension array.. 
Parameters @@ -505,7 +505,7 @@ def gather_nd(a, indices, batch_dims=0): ------- ret : tvm.te.Tensor """ - return cpp.gather_nd(a, indices, batch_dims) + return cpp.gather_nd(a, indices) def matmul(a, b, transp_a=False, transp_b=False): diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 58ba16871b64..db54d5a99a91 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -131,7 +131,7 @@ TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = gather_nd(args[0], args[1], args[2]); + *rv = gather_nd(args[0], args[1]); }); TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* rv) { From 1fee76c88d11e40414811b93d080afa2b82df9b3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 3 Jun 2021 05:41:38 +0900 Subject: [PATCH 14/17] do not try to support dynamic batch size in tile rep --- python/tvm/relay/frontend/tensorflow.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4c39a5a4dfec..4599974ac1b0 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -817,11 +817,7 @@ def convert_combined_nms_with_all_class_nms( box_range = _op.arange( _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64" ) - tile_batch_reps = ( - _op.concatenate([batch_size, 1], axis=0) - if isinstance(batch_size, tvm.tir.Any) - else _op.const([batch_size, 1]) - ) + tile_batch_reps = _op.const([batch_size, 1]) box_range_2d = _op.tile(box_range, tile_batch_reps) valid_mask = _op.cast( _op.less(box_range_2d, _op.expand_dims(num_detections, axis=1)), "float32" From 7c43cb580062069aeb6da7d0babc907b2fb258d3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 3 Jun 2021 05:46:15 +0900 Subject: [PATCH 15/17] check batch_size is int --- python/tvm/relay/frontend/tensorflow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4599974ac1b0..93b75cb8ac5b 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -817,6 +817,7 @@ def convert_combined_nms_with_all_class_nms( box_range = _op.arange( _op.const(0, dtype="int64"), _op.const(max_total_size, dtype="int64"), dtype="int64" ) + assert isinstance(batch_size, int), "dynamic batch size not supported yet." 
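For reference, a small NumPy sketch of the validity mask that this frontend builds from box_range and num_detections just below the static batch_size check. The sizes here are made up; the real code operates on Relay expressions, not arrays.

    import numpy as np

    batch_size, max_total_size = 2, 4          # illustrative static sizes
    num_detections = np.array([3, 1])          # valid detections per batch element

    box_range = np.arange(max_total_size)                    # (max_total_size,)
    box_range_2d = np.tile(box_range, (batch_size, 1))       # (batch_size, max_total_size)
    # 1.0 where a slot holds a real detection, 0.0 where it is padding
    valid_mask = (box_range_2d < num_detections[:, None]).astype("float32")
    # valid_mask == [[1., 1., 1., 0.],
    #                [1., 0., 0., 0.]]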
tile_batch_reps = _op.const([batch_size, 1]) box_range_2d = _op.tile(box_range, tile_batch_reps) valid_mask = _op.cast( From 3ed960072e8298f58361b66cf11b27edf92b6595 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 3 Jun 2021 06:31:23 +0900 Subject: [PATCH 16/17] fix dtype issue in scan --- python/tvm/topi/cuda/scan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py index 6dbaf02191c8..0d19a92f2058 100644 --- a/python/tvm/topi/cuda/scan.py +++ b/python/tvm/topi/cuda/scan.py @@ -231,7 +231,7 @@ def ir(data, data_ex_scan, reduction): data[tid * scan_axis_size + scan_axis_size - 1], ) with ib.else_scope(): - reduction[tid] = 0 + reduction[tid] = cast(0, reduction.dtype) return ib.get() From ef2eae5809f18d58bc1f920118641447a9942962 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 3 Jun 2021 06:36:22 +0900 Subject: [PATCH 17/17] fix slicing before topk --- python/tvm/relay/frontend/tensorflow.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 93b75cb8ac5b..fdd58bb53ba5 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -842,10 +842,9 @@ def true_branch(): def false_branch(): if isinstance(max_output_boxes_per_class, int): # Do topk on smaller input if possible - # TODO(masahi): use axes argument in strided slice when it becomes available - slice_mx = _op.const([-1, max_output_boxes_per_class * num_class], dtype="int64") + slice_mx = _op.const([max_output_boxes_per_class * num_class], dtype="int64") selected_scores_slice = _op.strided_slice( - selected_scores, begin=_op.const([0, 0], dtype="int64"), end=slice_mx + selected_scores, begin=_op.const([0], dtype="int64"), end=slice_mx, axes=[1] ) else: selected_scores_slice = selected_scores
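A closing note on the slice-before-topk restored in the last patch: per batch, the core NMS selects at most max_output_boxes_per_class boxes for each of num_class classes, and the collection step packs those entries at the front of selected_scores, so every column past max_output_boxes_per_class * num_class is padding. The NumPy check below illustrates the equivalence with made-up numbers; it is not the Relay code itself.

    import numpy as np

    num_class, num_boxes = 2, 5
    max_output_boxes_per_class, max_total_size = 2, 3
    # (batch, num_class * num_boxes) collected scores; at most
    # max_output_boxes_per_class * num_class = 4 leading columns can be non-zero
    selected_scores = np.array([[0.9, 0.8, 0.7, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])

    sliced = selected_scores[:, : max_output_boxes_per_class * num_class]
    topk_idx = np.argsort(-sliced, axis=1)[:, :max_total_size]
    topk_scores = np.take_along_axis(sliced, topk_idx, axis=1)
    # identical to running topk on the full, unsliced row
    assert np.allclose(topk_scores, np.sort(selected_scores, axis=1)[:, ::-1][:, :max_total_size])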