From 4d11bcf30549aac4b0e22e4c58e42c4c06acbdf9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 05:02:15 +0900 Subject: [PATCH 1/9] add a pattern to rewrite nms to batched nms --- python/tvm/relay/frontend/pytorch_utils.py | 97 ++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index d0f0b9b4b019..8db619efa1d9 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=import-outside-toplevel """ Common utilities used by PyTorch frontend """ +from .. import op +from ..dataflow_pattern import * def is_version_greater_than(ver): @@ -25,3 +27,98 @@ def is_version_greater_than(ver): return "".join(re.findall(r"(\d+\.)(\d+\.)(\d)", torch.__version__)[0]) > "".join( re.findall(r"(\d+\.)(\d+\.)(\d)", ver)[0] ) + + +def batched_nms_pattern(boxes, scores, idxs, iou_threshold): + one = is_constant() + zero = is_constant() + + # %1796 = expand_dims(%1795, axis=-1); + score_expand_dims = is_op("expand_dims")(scores) + + # %1824 = cast(%1823, dtype="float32"); + cast = is_op("cast")(idxs) + mx = is_op("max")(boxes) + add = is_op("add")(mx, one) + mul = is_op("multiply")(cast, add) + + # %1828 = cast_like(0, meta[op.Constant][127]); + cast_like = is_op("cast_like")(zero, is_constant()) + less = is_op("less")(is_constant(), cast_like) + shape_of = is_op("shape_of")(mul) + cast_like = is_op("cast_like")(shape_of, is_constant()) + add = is_op("add")(is_constant(), cast_like) + where = is_op("where")(less, add, is_constant()) + shape_of = is_op("shape_of")(mul) + cast = is_op("cast")(shape_of) + + # %1836 = dyn.strided_slice(%1827, %1833, %1835, meta[op.Constant][128], begin=None, end=None, strides=None); + dyn_strided_slice = is_op("dyn.strided_slice")(mul, where, cast, is_constant()) + + expand_dims = is_op("expand_dims")(dyn_strided_slice) + add = is_op("add")(boxes, expand_dims) + tup = is_tuple([score_expand_dims, add]) + concat = is_op("concatenate")(tup) + expand_dims = is_op("expand_dims")(concat) + + # %1842 = vision.get_valid_counts(%1841, -1f, meta[op.attrs.GetValidCountsAttrs][1]); + get_valid_counts_out = is_op("vision.get_valid_counts")(expand_dims, is_constant()) + data = is_tuple_get_item(get_valid_counts_out, 1) + valid_counts = is_tuple_get_item(get_valid_counts_out, 0) + indices = is_tuple_get_item(get_valid_counts_out, 2) + + # %1169 = vision.non_max_suppression(%1166, %1167, %1168, -1, 0.7f, meta[op.attrs.NonMaximumSuppressionAttrs][0]); + return is_op("vision.non_max_suppression")( + data, valid_counts, indices, is_constant(), iou_threshold + ) + + +def convert_batched_nms(boxes, scores, idxs, iou_thres): + scores = op.expand_dims(scores, axis=-1, num_newaxis=1) + idxs = op.expand_dims(idxs, axis=-1, num_newaxis=1) + idxs = op.cast(idxs, "float32") + data = op.concatenate([idxs, scores, boxes], -1) + data = op.expand_dims(data, 0, 1) + ct, data, indices = op.vision.get_valid_counts( + data, score_threshold=-1.0, id_index=0, score_index=1 + ) + top_k = max_out_size = -1 + out = op.vision.non_max_suppression( + data=data, + valid_count=ct, + indices=indices, + max_output_size=max_out_size, + iou_threshold=iou_thres, + force_suppress=True, + top_k=top_k, + coord_start=1, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + return out.tuple_value + + +class NMSRewrite(DFPatternCallback): + def __init__(self): + super().__init__() + # exprs I want to extract + self.boxes = wildcard() + self.scores = wildcard() + self.idxs = wildcard() + self.iou_threshold = wildcard() + self.pattern = batched_nms_pattern(self.boxes, self.scores, self.idxs, self.iou_threshold) + + def callback(self, _, node_map): + print("matched") + boxes = node_map[self.boxes][0] + scores = node_map[self.scores][0] + idxs = node_map[self.idxs][0] + iou_thres = node_map[self.iou_threshold][0] + return convert_batched_nms(boxes, scores, idxs, iou_thres) + + +def rewrite_nms_to_batched_nms(mod): + mod["main"] = rewrite(NMSRewrite(), mod["main"]) + return mod From 24e40929b1224d2d5f6ed5fca41f40832a39d4a3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 05:04:44 +0900 Subject: [PATCH 2/9] update object detection test to add rewrite --- .../frontend/pytorch/test_object_detection.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index e4545ec4ef5e..113fedbd328b 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -26,6 +26,7 @@ import tvm.testing from tvm import relay from tvm.runtime.vm import VirtualMachine +from tvm.relay.frontend.pytorch_utils import rewrite_nms_to_batched_nms from tvm.contrib.download import download @@ -108,15 +109,17 @@ def test_detection_models(): with torch.no_grad(): pt_res = scripted_model(data) - for target in ["llvm", "cuda"]: - with tvm.transform.PassContext(opt_level=3): + def compile_and_run_vm(mod, params, data_np): + with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): vm_exec = relay.vm.compile(mod, target=target, params=params) - ctx = tvm.context(target, 0) + ctx = tvm.cpu() vm = VirtualMachine(vm_exec, ctx) - vm.set_input("main", **{input_name: data_np}) - tvm_res = vm.run() + return vm.run() + + for target in ["cuda", "llvm"]: + tvm_res = compile_and_run_vm(mod, params, data_np) # Bounding boxes tvm.testing.assert_allclose( @@ -132,3 +135,14 @@ def test_detection_models(): score_threshold = 0.9 print("Num boxes:", pt_res[0].cpu().numpy().shape[0]) print("Num valid boxes:", np.sum(pt_res[1].cpu().numpy() >= score_threshold)) + + before = mod["main"] + mod = rewrite_nms_to_batched_nms(mod) + after = mod["main"] + assert not tvm.ir.structural_equal(after, before) + + tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np) + + # Results should be equivalent after rewriting + for res1, res2 in zip(tvm_res, tvm_res_after_rewrite): + tvm.testing.assert_allclose(res1, res2) From f375eb6036eed86bebf698df08662e868df4f78b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 05:27:26 +0900 Subject: [PATCH 3/9] updated tutorial --- python/tvm/relay/frontend/pytorch_utils.py | 5 ++--- tutorials/frontend/deploy_object_detection_pytorch.py | 10 +++++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 8db619efa1d9..0279484af867 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=import-outside-toplevel +# pylint: disable=import-outside-toplevel, unused-argument """ Common utilities used by PyTorch frontend """ from .. import op from ..dataflow_pattern import * @@ -110,8 +110,7 @@ def __init__(self): self.iou_threshold = wildcard() self.pattern = batched_nms_pattern(self.boxes, self.scores, self.idxs, self.iou_threshold) - def callback(self, _, node_map): - print("matched") + def callback(self, pre, post, node_map): boxes = node_map[self.boxes][0] scores = node_map[self.scores][0] idxs = node_map[self.idxs][0] diff --git a/tutorials/frontend/deploy_object_detection_pytorch.py b/tutorials/frontend/deploy_object_detection_pytorch.py index 2852dd3ad99d..58db6c63590e 100644 --- a/tutorials/frontend/deploy_object_detection_pytorch.py +++ b/tutorials/frontend/deploy_object_detection_pytorch.py @@ -42,7 +42,7 @@ import tvm from tvm import relay -from tvm import relay +from tvm.relay.frontend.pytorch_utils import rewrite_nms_to_batched_nms from tvm.runtime.vm import VirtualMachine from tvm.contrib.download import download @@ -115,6 +115,14 @@ def forward(self, inp): shape_list = [(input_name, input_shape)] mod, params = relay.frontend.from_pytorch(script_module, shape_list) +###################################################################### +# Rewrite the graph for more performance +# ------------------------- +# We provide a graph rewrite utility to replace the costly non maximum +# surpression in torchvision that does not take class id into account +# with more efficient one. +mod = rewrite_nms_to_batched_nms(mod) + ###################################################################### # Compile with Relay VM # --------------------- From 6dd8967af687f3bc9cedb6f68618921e5f2fdd84 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 06:50:46 +0900 Subject: [PATCH 4/9] add doc --- python/tvm/relay/frontend/pytorch_utils.py | 62 ++++++++++++---------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 0279484af867..458fbc1115c2 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -30,6 +30,7 @@ def is_version_greater_than(ver): def batched_nms_pattern(boxes, scores, idxs, iou_threshold): + """A pattern to detect batched_nms function in torchvision""" one = is_constant() zero = is_constant() @@ -73,34 +74,9 @@ def batched_nms_pattern(boxes, scores, idxs, iou_threshold): ) -def convert_batched_nms(boxes, scores, idxs, iou_thres): - scores = op.expand_dims(scores, axis=-1, num_newaxis=1) - idxs = op.expand_dims(idxs, axis=-1, num_newaxis=1) - idxs = op.cast(idxs, "float32") - data = op.concatenate([idxs, scores, boxes], -1) - data = op.expand_dims(data, 0, 1) - ct, data, indices = op.vision.get_valid_counts( - data, score_threshold=-1.0, id_index=0, score_index=1 - ) - top_k = max_out_size = -1 - out = op.vision.non_max_suppression( - data=data, - valid_count=ct, - indices=indices, - max_output_size=max_out_size, - iou_threshold=iou_thres, - force_suppress=True, - top_k=top_k, - coord_start=1, - score_index=1, - id_index=0, - return_indices=True, - invalid_to_bottom=False, - ) - return out.tuple_value - - class NMSRewrite(DFPatternCallback): + """A callback to rewrite nms and restore batched nms""" + def __init__(self): super().__init__() # exprs I want to extract @@ -110,14 +86,44 @@ def __init__(self): self.iou_threshold = wildcard() self.pattern = batched_nms_pattern(self.boxes, self.scores, self.idxs, self.iou_threshold) + def convert_batched_nms(self, boxes, scores, idxs, iou_thres): + scores = op.expand_dims(scores, axis=-1, num_newaxis=1) + idxs = op.expand_dims(idxs, axis=-1, num_newaxis=1) + idxs = op.cast(idxs, "float32") + data = op.concatenate([idxs, scores, boxes], -1) + data = op.expand_dims(data, 0, 1) + ct, data, indices = op.vision.get_valid_counts( + data, score_threshold=-1.0, id_index=0, score_index=1 + ) + top_k = max_out_size = -1 + out = op.vision.non_max_suppression( + data=data, + valid_count=ct, + indices=indices, + max_output_size=max_out_size, + iou_threshold=iou_thres, + force_suppress=True, + top_k=top_k, + coord_start=1, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + ) + return out.tuple_value + def callback(self, pre, post, node_map): boxes = node_map[self.boxes][0] scores = node_map[self.scores][0] idxs = node_map[self.idxs][0] iou_thres = node_map[self.iou_threshold][0] - return convert_batched_nms(boxes, scores, idxs, iou_thres) + return self.convert_batched_nms(boxes, scores, idxs, iou_thres) def rewrite_nms_to_batched_nms(mod): + """Rewrite the input graph to replace the costly non maximum surpression + in torchvision that does not take class id into account with the one + that avoids IOU tests between different classes. + """ mod["main"] = rewrite(NMSRewrite(), mod["main"]) return mod From ee4f8f097c64b5fc3c6aefab8f66d8aa09ca0758 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 07:22:32 +0900 Subject: [PATCH 5/9] fixed coord_start --- python/tvm/relay/frontend/pytorch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 458fbc1115c2..58f576bef45e 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -104,7 +104,7 @@ def convert_batched_nms(self, boxes, scores, idxs, iou_thres): iou_threshold=iou_thres, force_suppress=True, top_k=top_k, - coord_start=1, + coord_start=2, score_index=1, id_index=0, return_indices=True, From e62ada0a4d780e1a53ec2c4e5db479183199eef9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 08:02:29 +0900 Subject: [PATCH 6/9] test fixed by setting force_surpress=False --- python/tvm/relay/frontend/pytorch_utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 58f576bef45e..51473bbbabef 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -17,7 +17,15 @@ # pylint: disable=import-outside-toplevel, unused-argument """ Common utilities used by PyTorch frontend """ from .. import op -from ..dataflow_pattern import * +from ..dataflow_pattern import ( + is_constant, + is_op, + rewrite, + is_tuple, + is_tuple_get_item, + wildcard, + DFPatternCallback, +) def is_version_greater_than(ver): @@ -79,7 +87,7 @@ class NMSRewrite(DFPatternCallback): def __init__(self): super().__init__() - # exprs I want to extract + # exprs to extract self.boxes = wildcard() self.scores = wildcard() self.idxs = wildcard() @@ -102,7 +110,7 @@ def convert_batched_nms(self, boxes, scores, idxs, iou_thres): indices=indices, max_output_size=max_out_size, iou_threshold=iou_thres, - force_suppress=True, + force_suppress=False, top_k=top_k, coord_start=2, score_index=1, From 4749f9a5fd47d5c883a170d9f78ec4703f019c85 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Dec 2020 08:35:14 +0900 Subject: [PATCH 7/9] revert tutorial change --- python/tvm/relay/frontend/pytorch_utils.py | 11 +++-------- tutorials/frontend/deploy_object_detection_pytorch.py | 10 +--------- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 51473bbbabef..31f3be258a27 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=import-outside-toplevel, unused-argument +# pylint: disable=import-outside-toplevel, unused-argument, invalid-name """ Common utilities used by PyTorch frontend """ from .. import op from ..dataflow_pattern import ( @@ -42,16 +42,13 @@ def batched_nms_pattern(boxes, scores, idxs, iou_threshold): one = is_constant() zero = is_constant() - # %1796 = expand_dims(%1795, axis=-1); score_expand_dims = is_op("expand_dims")(scores) - # %1824 = cast(%1823, dtype="float32"); cast = is_op("cast")(idxs) mx = is_op("max")(boxes) add = is_op("add")(mx, one) mul = is_op("multiply")(cast, add) - # %1828 = cast_like(0, meta[op.Constant][127]); cast_like = is_op("cast_like")(zero, is_constant()) less = is_op("less")(is_constant(), cast_like) shape_of = is_op("shape_of")(mul) @@ -61,7 +58,6 @@ def batched_nms_pattern(boxes, scores, idxs, iou_threshold): shape_of = is_op("shape_of")(mul) cast = is_op("cast")(shape_of) - # %1836 = dyn.strided_slice(%1827, %1833, %1835, meta[op.Constant][128], begin=None, end=None, strides=None); dyn_strided_slice = is_op("dyn.strided_slice")(mul, where, cast, is_constant()) expand_dims = is_op("expand_dims")(dyn_strided_slice) @@ -70,13 +66,11 @@ def batched_nms_pattern(boxes, scores, idxs, iou_threshold): concat = is_op("concatenate")(tup) expand_dims = is_op("expand_dims")(concat) - # %1842 = vision.get_valid_counts(%1841, -1f, meta[op.attrs.GetValidCountsAttrs][1]); get_valid_counts_out = is_op("vision.get_valid_counts")(expand_dims, is_constant()) data = is_tuple_get_item(get_valid_counts_out, 1) valid_counts = is_tuple_get_item(get_valid_counts_out, 0) indices = is_tuple_get_item(get_valid_counts_out, 2) - # %1169 = vision.non_max_suppression(%1166, %1167, %1168, -1, 0.7f, meta[op.attrs.NonMaximumSuppressionAttrs][0]); return is_op("vision.non_max_suppression")( data, valid_counts, indices, is_constant(), iou_threshold ) @@ -95,6 +89,7 @@ def __init__(self): self.pattern = batched_nms_pattern(self.boxes, self.scores, self.idxs, self.iou_threshold) def convert_batched_nms(self, boxes, scores, idxs, iou_thres): + """Restore class-aware NMS using extracted class indices""" scores = op.expand_dims(scores, axis=-1, num_newaxis=1) idxs = op.expand_dims(idxs, axis=-1, num_newaxis=1) idxs = op.cast(idxs, "float32") @@ -129,7 +124,7 @@ def callback(self, pre, post, node_map): def rewrite_nms_to_batched_nms(mod): - """Rewrite the input graph to replace the costly non maximum surpression + """Rewrite the input graph to replace non maximum surpression in torchvision that does not take class id into account with the one that avoids IOU tests between different classes. """ diff --git a/tutorials/frontend/deploy_object_detection_pytorch.py b/tutorials/frontend/deploy_object_detection_pytorch.py index 58db6c63590e..2852dd3ad99d 100644 --- a/tutorials/frontend/deploy_object_detection_pytorch.py +++ b/tutorials/frontend/deploy_object_detection_pytorch.py @@ -42,7 +42,7 @@ import tvm from tvm import relay -from tvm.relay.frontend.pytorch_utils import rewrite_nms_to_batched_nms +from tvm import relay from tvm.runtime.vm import VirtualMachine from tvm.contrib.download import download @@ -115,14 +115,6 @@ def forward(self, inp): shape_list = [(input_name, input_shape)] mod, params = relay.frontend.from_pytorch(script_module, shape_list) -###################################################################### -# Rewrite the graph for more performance -# ------------------------- -# We provide a graph rewrite utility to replace the costly non maximum -# surpression in torchvision that does not take class id into account -# with more efficient one. -mod = rewrite_nms_to_batched_nms(mod) - ###################################################################### # Compile with Relay VM # --------------------- From bcdf77418eb6225c8726fa29a1bd889b96b49aa9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 24 Dec 2020 05:19:31 +0900 Subject: [PATCH 8/9] add some comment to explain the pattern --- python/tvm/relay/frontend/pytorch_utils.py | 48 ++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 31f3be258a27..97f56e11c4d4 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -38,17 +38,53 @@ def is_version_greater_than(ver): def batched_nms_pattern(boxes, scores, idxs, iou_threshold): - """A pattern to detect batched_nms function in torchvision""" + """A pattern to detect batched_nms function in torchvision + + The inputs to this function, boxes, scores, idxs, iou_threshold are wildcard + patterns which can be used later in the rewriting to extract matched Relay fragments. + + We want to detect the following PyTorch code snippet: + + def batched_nms(boxes, scores, idxs, iou_threshold): + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) + boxes_for_nms = boxes + offsets[:, None] + keep = nms(boxes_for_nms, scores, iou_threshold) + return keep + + Here is how PyTorch frontend lowers above PyTorch code. For simplicity, Relay ops for + dealing with dynamic strided_slice are omitted. + + %2 = expand_dims(%scores, axis=-1); + %3 = cast(%idxs, dtype="float32"); + %4 = max(%boxes); + %5 = add(%4, 1f); + %6 = multiply(%3, %5); + %7 = strided_slice(%6, begin=[0], end=[4507], strides=[1]); + %8 = expand_dims(%7, axis=1); + %9 = add(%boxes, %8); + %10 = (%2, %9); + %11 = concatenate(%10, axis=-1); + %12 = expand_dims(%11, axis=0); + %13 = vision.get_valid_counts(%12, -1f, meta[relay.attrs.GetValidCountsAttrs][0]); + %14 = %13.1; + %15 = %13.0; + %16 = %13.2; + %17 = vision.non_max_suppression(%14, %15, %16, -1, 0.7f, ...); + + """ one = is_constant() zero = is_constant() - score_expand_dims = is_op("expand_dims")(scores) - + # Equivelent PyTorch code from above snippet + # offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) cast = is_op("cast")(idxs) mx = is_op("max")(boxes) add = is_op("add")(mx, one) mul = is_op("multiply")(cast, add) + # The following doesn't appear in the above Relay snippet. It is required for dynamic + # stride_slice handling cast_like = is_op("cast_like")(zero, is_constant()) less = is_op("less")(is_constant(), cast_like) shape_of = is_op("shape_of")(mul) @@ -58,10 +94,16 @@ def batched_nms_pattern(boxes, scores, idxs, iou_threshold): shape_of = is_op("shape_of")(mul) cast = is_op("cast")(shape_of) + # This corresponds to offsets[:, None], where offsets is the result of multiplication dyn_strided_slice = is_op("dyn.strided_slice")(mul, where, cast, is_constant()) + # Add offsets to the boxes expand_dims = is_op("expand_dims")(dyn_strided_slice) add = is_op("add")(boxes, expand_dims) + + # The rest of patterns correspond to the PyTorch frontend conversion + # function for torchvision::nms + score_expand_dims = is_op("expand_dims")(scores) tup = is_tuple([score_expand_dims, add]) concat = is_op("concatenate")(tup) expand_dims = is_op("expand_dims")(concat) From af38f4d05ed850e528099dedbb40e722bfb1ca84 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 26 Dec 2020 11:21:33 +0900 Subject: [PATCH 9/9] update NMS pattern following frontend change --- python/tvm/relay/frontend/pytorch.py | 14 +++--- python/tvm/relay/frontend/pytorch_utils.py | 48 ++++++++++--------- .../frontend/pytorch/test_object_detection.py | 12 ++--- 3 files changed, 39 insertions(+), 35 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 94ee9282e4fa..cf2cbad0a78e 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1857,18 +1857,18 @@ def nms(self, inputs, input_types): scores = inputs[1] iou_threshold = inputs[2] - num_boxes = _op.shape_of(scores) - # TVM NMS assumes score > 0 scores = scores - _op.min(scores) + _op.const(1.0) + + num_boxes = _op.shape_of(scores) + # PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count + indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32") + indices = _op.expand_dims(indices, 0, 1) + # Generate data with shape (1, num_anchors, 5) scores = AttrCvt(op_name="expand_dims", extras={"axis": -1, "num_newaxis": 1})([scores], {}) data = _op.concatenate([scores, boxes], -1) data = _op.expand_dims(data, 0, 1) - # PyTorch NMS doesn't have score_threshold, so no need to run get_valid_count - indices = _op.transform.arange(_op.squeeze(num_boxes), dtype="int32") - indices = _op.expand_dims(indices, 0, 1) - ct = num_boxes # Perform Non-Maximum Suppression, # PyTorch NMS doesn't have parameter top_k and max_output_size @@ -1876,7 +1876,7 @@ def nms(self, inputs, input_types): top_k = max_out_size = -1 nms_ret = get_relay_op("non_max_suppression")( data=data, - valid_count=ct, + valid_count=num_boxes, indices=indices, max_output_size=max_out_size, iou_threshold=iou_threshold, diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index 97f56e11c4d4..6fc5a6af4a36 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -22,7 +22,6 @@ is_op, rewrite, is_tuple, - is_tuple_get_item, wildcard, DFPatternCallback, ) @@ -37,7 +36,7 @@ def is_version_greater_than(ver): ) -def batched_nms_pattern(boxes, scores, idxs, iou_threshold): +def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices): """A pattern to detect batched_nms function in torchvision The inputs to this function, boxes, scores, idxs, iou_threshold are wildcard @@ -53,7 +52,9 @@ def batched_nms(boxes, scores, idxs, iou_threshold): return keep Here is how PyTorch frontend lowers above PyTorch code. For simplicity, Relay ops for - dealing with dynamic strided_slice are omitted. + dealing with dynamic strided_slice are omitted. %num_boxes, %indices are complex + expressions, but since we can use the wildcard part for them, we do not need to construct + their patterns. %2 = expand_dims(%scores, axis=-1); %3 = cast(%idxs, dtype="float32"); @@ -66,11 +67,9 @@ def batched_nms(boxes, scores, idxs, iou_threshold): %10 = (%2, %9); %11 = concatenate(%10, axis=-1); %12 = expand_dims(%11, axis=0); - %13 = vision.get_valid_counts(%12, -1f, meta[relay.attrs.GetValidCountsAttrs][0]); - %14 = %13.1; - %15 = %13.0; - %16 = %13.2; - %17 = vision.non_max_suppression(%14, %15, %16, -1, 0.7f, ...); + ... + ... + %17 = vision.non_max_suppression(%12, %num_boxes, %indices, -1, 0.7f, ...); """ one = is_constant() @@ -106,15 +105,10 @@ def batched_nms(boxes, scores, idxs, iou_threshold): score_expand_dims = is_op("expand_dims")(scores) tup = is_tuple([score_expand_dims, add]) concat = is_op("concatenate")(tup) - expand_dims = is_op("expand_dims")(concat) - - get_valid_counts_out = is_op("vision.get_valid_counts")(expand_dims, is_constant()) - data = is_tuple_get_item(get_valid_counts_out, 1) - valid_counts = is_tuple_get_item(get_valid_counts_out, 0) - indices = is_tuple_get_item(get_valid_counts_out, 2) + data = is_op("expand_dims")(concat) return is_op("vision.non_max_suppression")( - data, valid_counts, indices, is_constant(), iou_threshold + data, num_boxes, indices, is_constant(), iou_threshold ) @@ -128,22 +122,30 @@ def __init__(self): self.scores = wildcard() self.idxs = wildcard() self.iou_threshold = wildcard() - self.pattern = batched_nms_pattern(self.boxes, self.scores, self.idxs, self.iou_threshold) + self.num_boxes = wildcard() + self.indices = wildcard() + + self.pattern = batched_nms_pattern( + self.boxes, + self.scores, + self.idxs, + self.iou_threshold, + self.num_boxes, + self.indices, + ) - def convert_batched_nms(self, boxes, scores, idxs, iou_thres): + def convert_batched_nms(self, boxes, scores, idxs, iou_thres, num_boxes, indices): """Restore class-aware NMS using extracted class indices""" scores = op.expand_dims(scores, axis=-1, num_newaxis=1) idxs = op.expand_dims(idxs, axis=-1, num_newaxis=1) idxs = op.cast(idxs, "float32") data = op.concatenate([idxs, scores, boxes], -1) data = op.expand_dims(data, 0, 1) - ct, data, indices = op.vision.get_valid_counts( - data, score_threshold=-1.0, id_index=0, score_index=1 - ) + top_k = max_out_size = -1 out = op.vision.non_max_suppression( data=data, - valid_count=ct, + valid_count=num_boxes, indices=indices, max_output_size=max_out_size, iou_threshold=iou_thres, @@ -162,7 +164,9 @@ def callback(self, pre, post, node_map): scores = node_map[self.scores][0] idxs = node_map[self.idxs][0] iou_thres = node_map[self.iou_threshold][0] - return self.convert_batched_nms(boxes, scores, idxs, iou_thres) + num_boxes = node_map[self.num_boxes][0] + indices = node_map[self.indices][0] + return self.convert_batched_nms(boxes, scores, idxs, iou_thres, num_boxes, indices) def rewrite_nms_to_batched_nms(mod): diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py index 113fedbd328b..2c323776f087 100644 --- a/tests/python/frontend/pytorch/test_object_detection.py +++ b/tests/python/frontend/pytorch/test_object_detection.py @@ -109,17 +109,17 @@ def test_detection_models(): with torch.no_grad(): pt_res = scripted_model(data) - def compile_and_run_vm(mod, params, data_np): - with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]): + def compile_and_run_vm(mod, params, data_np, target): + with tvm.transform.PassContext(opt_level=3): vm_exec = relay.vm.compile(mod, target=target, params=params) - ctx = tvm.cpu() + ctx = tvm.context(target, 0) vm = VirtualMachine(vm_exec, ctx) vm.set_input("main", **{input_name: data_np}) return vm.run() for target in ["cuda", "llvm"]: - tvm_res = compile_and_run_vm(mod, params, data_np) + tvm_res = compile_and_run_vm(mod, params, data_np, target) # Bounding boxes tvm.testing.assert_allclose( @@ -141,8 +141,8 @@ def compile_and_run_vm(mod, params, data_np): after = mod["main"] assert not tvm.ir.structural_equal(after, before) - tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np) + tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm") # Results should be equivalent after rewriting for res1, res2 in zip(tvm_res, tvm_res_after_rewrite): - tvm.testing.assert_allclose(res1, res2) + tvm.testing.assert_allclose(res1.asnumpy(), res2.asnumpy())