From 494f5444a7a5aa628ff73394715be0a6a36df82c Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Wed, 11 Dec 2019 10:34:38 +0000 Subject: [PATCH 01/10] [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess This adds support for the custom operator TFLite_Detection_PostProcess which is commonly used in object detection networks such as SSD Mobilenet. It only adds support for when use_regular_nms = False. Change-Id: I819b253c0eb6f0fa55da65d2634e09359b888828 --- python/tvm/relay/frontend/tflite.py | 196 ++++++++++++++++++++++++++++ 1 file changed, 196 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d3826b6ce52d..bba33a6f52a7 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -121,6 +121,7 @@ def __init__(self, model, subgraph, exp_tab): 'SQUARED_DIFFERENCE': self.convert_squared_difference, 'LOGICAL_AND': self.convert_logical_and, 'LOGICAL_OR': self.convert_logical_or, + 'DETECTION_POSTPROCESS': self._convert_detection_postprocess, } def check_unsupported_ops(self): @@ -168,6 +169,10 @@ def get_op_code_str(self, op): op_code_str = self.builtin_op_code[op_code_id] if op_code_id == BuiltinOperator.CUSTOM: # Custom operator + custom_op_code_str = self.model.OperatorCodes(op_code_list_idx).CustomCode() + if custom_op_code_str == b'TFLite_Detection_PostProcess': + return "DETECTION_POSTPROCESS" + raise NotImplementedError("Custom operators are currently not supported") return op_code_str @@ -1806,6 +1811,112 @@ def convert_transpose_conv(self, op): return out + def _convert_detection_postprocess(self, op): + """Convert TFLite_Detection_PostProcess""" + _option_names = [ + "w_scale", + "max_detections", + "_output_quantized", + "detections_per_class", + "x_scale", + "nms_score_threshold", + "num_classes", + "max_classes_per_detection", + "use_regular_nms", + "y_scale", + "h_scale", + "_support_output_type_float_in_quantized_op", + "nms_iou_threshold" + ] + + custom_options = get_custom_options(op, _option_names) + if custom_options["use_regular_nms"]: + raise tvm.error.OpAttributeUnImplemented( + "use_regular_nms=True is not yet supported for operator {}." + .format("TFLite_Detection_PostProcess") + ) + + inputs = self.get_input_tensors(op) + cls_pred = self.get_expr(inputs[1].tensor_idx) + loc_prob = self.get_expr(inputs[0].tensor_idx) + anchor_values = self.get_tensor_value(inputs[2]) + anchor_boxes = len(anchor_values) + anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type()) + anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type) + + if inputs[0].qnn_params: + loc_prob = _qnn.op.dequantize(data=loc_prob, + input_scale=inputs[0].qnn_params['scale'], + input_zero_point=inputs[0].qnn_params['zero_point']) + if inputs[1].qnn_params: + cls_pred = _qnn.op.dequantize(data=cls_pred, + input_scale=inputs[1].qnn_params['scale'], + input_zero_point=inputs[1].qnn_params['zero_point']) + if inputs[2].qnn_params: + anchor_expr = _qnn.op.dequantize(data=anchor_expr, + input_scale=inputs[2].qnn_params['scale'], + input_zero_point=inputs[2].qnn_params['zero_point']) + + # reshape the cls_pred and loc_prob tensors so + # they can be consumed by multibox_transform_loc + cls_pred = _op.transpose(cls_pred, [0, 2, 1]) + # loc_prob coords are in yxhw format + # need to convert to xywh + loc_coords = _op.split(loc_prob, 4, axis=2) + loc_prob = _op.concatenate( + [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2 + ) + loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4]) + + # anchor coords are in yxhw format + # need to convert to ltrb + anchor_coords = _op.split(anchor_expr, 4, axis=1) + anchor_y = anchor_coords[0] + anchor_x = anchor_coords[1] + anchor_h = anchor_coords[2] + anchor_w = anchor_coords[3] + plus_half = _expr.const(0.5, dtype='float32') + minus_half = _expr.const(-0.5, dtype='float32') + anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half)) + anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half)) + anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half)) + anchor_b = _op.add(anchor_y, _op.multiply(anchor_h, plus_half)) + anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r, anchor_b], axis=1) + anchor_expr = _op.expand_dims(anchor_expr, 0) + + # attributes for multibox_transform_loc + new_attrs0 = {} + new_attrs0["clip"] = False + new_attrs0["threshold"] = custom_options["nms_score_threshold"] + new_attrs0["variances"] = ( + 1/custom_options["x_scale"], + 1/custom_options["y_scale"], + 1/custom_options["w_scale"], + 1/custom_options["h_scale"], + ) + + # attributes for non_max_suppression + new_attrs1 = {} + new_attrs1["return_indices"] = False + new_attrs1["iou_threshold"] = custom_options["nms_iou_threshold"] + new_attrs1["force_suppress"] = False + new_attrs1["top_k"] = anchor_boxes + new_attrs1["max_output_size"] = custom_options["max_detections"] + new_attrs1["invalid_to_bottom"] = False + + ret = _op.vision.multibox_transform_loc(cls_pred, loc_prob, + anchor_expr, **new_attrs0) + ret = _op.vision.non_max_suppression(ret[0], ret[1], **new_attrs1) + ret = _op.vision.get_valid_counts(ret, 0) + valid_count = ret[0] + # the output needs some reshaping to match tflite + ret = _op.split(ret[1], 6, axis=2) + cls_ids = ret[0] + scores = ret[1] + boxes = _op.concatenate([ret[3], ret[2], ret[5], ret[4]], axis=2) + ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4) + return ret + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) @@ -1877,6 +1988,91 @@ def get_tensor_name(subgraph, tensor_idx): return subgraph.Tensors(tensor_idx).Name().decode("utf-8") +def get_custom_options(op, option_names): + """Get the options of a custom operator. + + This implements partial flexbuffer deserialization to be able + to read custom options. It is not intended to be a general + purpose flexbuffer deserializer and as such only supports a + limited number of types and assumes the data is a flat map. + + Parameters + ---------- + op: + A custom TFlite operator. + option_names: list + A complete list of the custom option names. + + Returns + ------- + options: dict + A dictionary of the custom options. + + """ + import struct + from enum import IntEnum + + class _FlexBufferType(IntEnum): + """Flexbuffer type schema from flexbuffers.h""" + FBT_NULL = 0 + FBT_INT = 1 + FBT_UINT = 2 + FBT_FLOAT = 3 + # Types above stored inline, types below store an offset. + FBT_KEY = 4 + FBT_STRING = 5 + FBT_INDIRECT_INT = 6 + FBT_INDIRECT_UINT = 7 + FBT_INDIRECT_FLOAT = 8 + FBT_MAP = 9 + FBT_VECTOR = 10 # Untyped. + FBT_VECTOR_INT = 11 # Typed any size (stores no type table). + FBT_VECTOR_UINT = 12 + FBT_VECTOR_FLOAT = 13 + FBT_VECTOR_KEY = 14 + FBT_VECTOR_STRING = 15 + FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field). + FBT_VECTOR_UINT2 = 17 + FBT_VECTOR_FLOAT2 = 18 + FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field). + FBT_VECTOR_UINT3 = 20 + FBT_VECTOR_FLOAT3 = 21 + FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field). + FBT_VECTOR_UINT4 = 23 + FBT_VECTOR_FLOAT4 = 24 + FBT_BLOB = 25 + FBT_BOOL = 26 + FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type + + buffer = op.CustomOptionsAsNumpy().tobytes() + value_vector_offset = buffer[-3] + buffer = buffer[:-3] + num_bytes = 4 # Assume all values are stored in 32 bit width + value_vector_size = struct.unpack( + "> 2) + value_offset = -value_vector_offset + i*num_bytes + value_bytes = buffer[value_offset:value_offset+num_bytes] + if flex_type == _FlexBufferType.FBT_BOOL: + value = True if value_bytes[0] else False + if flex_type == _FlexBufferType.FBT_INT: + value = struct.unpack(" Date: Wed, 18 Dec 2019 09:22:07 +0000 Subject: [PATCH 02/10] Added a test for the tflite custom op Change-Id: Ie5baa092deae9a8bcffd2ebd9f6d346b90e58afd --- tests/python/frontend/tflite/test_forward.py | 46 ++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index ad1abc247f7e..54eb017fecaa 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1353,6 +1353,49 @@ def test_forward_fully_connected(): _test_fully_connected([5, 1, 1, 150], [150, 100], [100]) +####################################################################### +# Custom Operators +# ------- + +def test_detection_postprocess(): + tf_model_file = tf_testing.get_workload_official( + "http://download.tensorflow.org/models/object_detection/" + "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz", + "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb" + ) + converter = tf.lite.TFLiteConverter.from_frozen_graph( + tf_model_file, + input_arrays=["raw_outputs/box_encodings", "raw_outputs/class_predictions"], + output_arrays=[ + "TFLite_Detection_PostProcess", + "TFLite_Detection_PostProcess:1", + "TFLite_Detection_PostProcess:2", + "TFLite_Detection_PostProcess:3" + ], + input_shapes={ + "raw_outputs/box_encodings": (1, 1917, 4), + "raw_outputs/class_predictions": (1, 1917, 91), + }, + ) + converter.allow_custom_ops = True + converter.inference_type = tf.lite.constants.FLOAT + tflite_model = converter.convert() + box_encodings = np.random.uniform(size=(1, 1917, 4)).astype('float32') + class_predictions = np.random.uniform(size=(1, 1917, 91)).astype('float32') + tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions]) + tvm_output = run_tvm_graph(tflite_model, [box_encodings, class_predictions], ["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4) + # check valid count is the same + assert tvm_output[3] == tflite_output[3] + valid_count = tvm_output[3][0] + tvm_boxes = tvm_output[0][0][:valid_count] + tvm_classes = tvm_output[1][0][:valid_count] + tvm_scores = tvm_output[2][0][:valid_count] + # check the output data is correct + tvm.testing.assert_allclose(np.squeeze(tvm_boxes), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose(np.squeeze(tvm_classes), np.squeeze(tflite_output[1]), rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose(np.squeeze(tvm_scores), np.squeeze(tflite_output[2]), rtol=1e-5, atol=1e-5) + + ####################################################################### # Mobilenet # --------- @@ -1573,6 +1616,9 @@ def test_forward_mediapipe_hand_landmark(): # Logical test_all_logical() + # Detection_PostProcess + test_detection_postprocess() + # End to End test_forward_mobilenet_v1() test_forward_mobilenet_v2() From eeb611e553052177eebc87b1e84583ee5baef257 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 6 Jan 2020 08:32:35 +0000 Subject: [PATCH 03/10] Removed trailing comma Change-Id: Ib08f02b5f1a59a883048bfb36e4321152cd2e7f2 --- python/tvm/relay/frontend/tflite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index bba33a6f52a7..cfdc91c7c388 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -121,7 +121,7 @@ def __init__(self, model, subgraph, exp_tab): 'SQUARED_DIFFERENCE': self.convert_squared_difference, 'LOGICAL_AND': self.convert_logical_and, 'LOGICAL_OR': self.convert_logical_or, - 'DETECTION_POSTPROCESS': self._convert_detection_postprocess, + 'DETECTION_POSTPROCESS': self._convert_detection_postprocess } def check_unsupported_ops(self): From aa678cf0014fc51c2dd5069b50f0c4ef3be4943b Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 6 Jan 2020 08:33:38 +0000 Subject: [PATCH 04/10] Added spaces between divide Change-Id: If1171fc03d211a809cedeb800804394972af4060 --- python/tvm/relay/frontend/tflite.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index cfdc91c7c388..ae4b83acc1bd 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1889,10 +1889,10 @@ def _convert_detection_postprocess(self, op): new_attrs0["clip"] = False new_attrs0["threshold"] = custom_options["nms_score_threshold"] new_attrs0["variances"] = ( - 1/custom_options["x_scale"], - 1/custom_options["y_scale"], - 1/custom_options["w_scale"], - 1/custom_options["h_scale"], + 1 / custom_options["x_scale"], + 1 / custom_options["y_scale"], + 1 / custom_options["w_scale"], + 1 / custom_options["h_scale"], ) # attributes for non_max_suppression From 505e7f0d656589a44cb0cc7f6d4d8f38ef84a95e Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 6 Jan 2020 08:34:51 +0000 Subject: [PATCH 05/10] Formatted comment Change-Id: I3ce7e69b8d2c73aec57369c1c64ea1eec07f087b --- tests/python/frontend/tflite/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 54eb017fecaa..a595d53d2bf1 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1355,7 +1355,7 @@ def test_forward_fully_connected(): ####################################################################### # Custom Operators -# ------- +# ---------------- def test_detection_postprocess(): tf_model_file = tf_testing.get_workload_official( From a7fc2a361692811e0eb8b0061782cd03079f29b3 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 6 Jan 2020 08:37:14 +0000 Subject: [PATCH 06/10] Reduced line length in test Change-Id: I49eaafc3369070f8f3e85fbb965ad20972096c68 --- tests/python/frontend/tflite/test_forward.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index a595d53d2bf1..875c77e11c2e 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1383,7 +1383,8 @@ def test_detection_postprocess(): box_encodings = np.random.uniform(size=(1, 1917, 4)).astype('float32') class_predictions = np.random.uniform(size=(1, 1917, 91)).astype('float32') tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions]) - tvm_output = run_tvm_graph(tflite_model, [box_encodings, class_predictions], ["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4) + tvm_output = run_tvm_graph(tflite_model, [box_encodings, class_predictions], + ["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4) # check valid count is the same assert tvm_output[3] == tflite_output[3] valid_count = tvm_output[3][0] From 5f9cba63f0338fe38ba1abdc54b2896db01e1803 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 6 Jan 2020 11:44:48 +0000 Subject: [PATCH 07/10] Set random seed for test Change-Id: I542a787d11422ea83c52147b2cb1144fcef0dd77 --- tests/python/frontend/tflite/test_forward.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 875c77e11c2e..db8cf78e6dd2 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1380,6 +1380,7 @@ def test_detection_postprocess(): converter.allow_custom_ops = True converter.inference_type = tf.lite.constants.FLOAT tflite_model = converter.convert() + np.random.seed(0) box_encodings = np.random.uniform(size=(1, 1917, 4)).astype('float32') class_predictions = np.random.uniform(size=(1, 1917, 91)).astype('float32') tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions]) From 7fca3457cf2ddf4aba9ec807ed01d6851dbc95b2 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 27 Jan 2020 10:52:38 +0000 Subject: [PATCH 08/10] Fixes to style Change-Id: I2971b8ecebe08c882b2481a99f67cfbe515e0b1f --- python/tvm/relay/frontend/tflite.py | 30 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index ae4b83acc1bd..ce2aaa6e947c 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -121,7 +121,7 @@ def __init__(self, model, subgraph, exp_tab): 'SQUARED_DIFFERENCE': self.convert_squared_difference, 'LOGICAL_AND': self.convert_logical_and, 'LOGICAL_OR': self.convert_logical_or, - 'DETECTION_POSTPROCESS': self._convert_detection_postprocess + 'DETECTION_POSTPROCESS': self.convert_detection_postprocess } def check_unsupported_ops(self): @@ -1811,7 +1811,7 @@ def convert_transpose_conv(self, op): return out - def _convert_detection_postprocess(self, op): + def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" _option_names = [ "w_scale", @@ -1885,10 +1885,10 @@ def _convert_detection_postprocess(self, op): anchor_expr = _op.expand_dims(anchor_expr, 0) # attributes for multibox_transform_loc - new_attrs0 = {} - new_attrs0["clip"] = False - new_attrs0["threshold"] = custom_options["nms_score_threshold"] - new_attrs0["variances"] = ( + multibox_transform_loc_attrs = {} + multibox_transform_loc_attrs["clip"] = False + multibox_transform_loc_attrs["threshold"] = custom_options["nms_score_threshold"] + multibox_transform_loc_attrs["variances"] = ( 1 / custom_options["x_scale"], 1 / custom_options["y_scale"], 1 / custom_options["w_scale"], @@ -1896,17 +1896,17 @@ def _convert_detection_postprocess(self, op): ) # attributes for non_max_suppression - new_attrs1 = {} - new_attrs1["return_indices"] = False - new_attrs1["iou_threshold"] = custom_options["nms_iou_threshold"] - new_attrs1["force_suppress"] = False - new_attrs1["top_k"] = anchor_boxes - new_attrs1["max_output_size"] = custom_options["max_detections"] - new_attrs1["invalid_to_bottom"] = False + non_max_suppression_attrs = {} + non_max_suppression_attrs["return_indices"] = False + non_max_suppression_attrs["iou_threshold"] = custom_options["nms_iou_threshold"] + non_max_suppression_attrs["force_suppress"] = False + non_max_suppression_attrs["top_k"] = anchor_boxes + non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"] + non_max_suppression_attrs["invalid_to_bottom"] = False ret = _op.vision.multibox_transform_loc(cls_pred, loc_prob, - anchor_expr, **new_attrs0) - ret = _op.vision.non_max_suppression(ret[0], ret[1], **new_attrs1) + anchor_expr, **multibox_transform_loc_attrs) + ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs) ret = _op.vision.get_valid_counts(ret, 0) valid_count = ret[0] # the output needs some reshaping to match tflite From 7e698397318a018022dc3127e2f68faedc250800 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 3 Feb 2020 16:14:11 +0000 Subject: [PATCH 09/10] Assert for incorrect number of inputs Change-Id: I393f3b3b62be73e427498d98456fb1d5a214e0af --- python/tvm/relay/frontend/tflite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index ce2aaa6e947c..fbfa1fcd1563 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1837,6 +1837,7 @@ def convert_detection_postprocess(self, op): ) inputs = self.get_input_tensors(op) + assert len(inputs) == 3, "inputs length should be 3" cls_pred = self.get_expr(inputs[1].tensor_idx) loc_prob = self.get_expr(inputs[0].tensor_idx) anchor_values = self.get_tensor_value(inputs[2]) From 5b7f549cace24a123f14fddb5c130ccf60454026 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Mon, 10 Feb 2020 12:08:09 +0000 Subject: [PATCH 10/10] Change comparison to pass linting The linter was updated, so I needed to fix a small style issue as a result. Change-Id: Ia3c954565a00de92e7fb1912eae9ed9875d60c7c --- python/tvm/relay/frontend/tflite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index fbfa1fcd1563..5650d99ab350 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2060,7 +2060,7 @@ class _FlexBufferType(IntEnum): value_offset = -value_vector_offset + i*num_bytes value_bytes = buffer[value_offset:value_offset+num_bytes] if flex_type == _FlexBufferType.FBT_BOOL: - value = True if value_bytes[0] else False + value = bool(value_bytes[0]) if flex_type == _FlexBufferType.FBT_INT: value = struct.unpack("