From 15e16561212ad6c0b226896b18f104bc674a6672 Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Fri, 24 Apr 2020 19:32:15 +0000 Subject: [PATCH 1/8] TFlite e2e FP32 Object detection model --- tests/python/frontend/tflite/test_forward.py | 24 ++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 283d87d5078a..a67cd9b7c51d 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1942,6 +1942,29 @@ def test_forward_qnn_mobilenet_v3_net(): tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) +####################################################################### +# SSD Mobilenet +# ------------- + +def test_forward_coco_ssd_mobilenet_v1(): + """Test the quantized Coco SSD Mobilenet V1 TF Lite model.""" + tflite_model_file = tf_testing.get_workload_official( + "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tgz", + "ssd_mobilenet_v1_coco_2018_01_28.tflite") + + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + np.random.seed(0) + data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32') + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=2) + for i in range(2): + tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]), + rtol=1e-5, atol=2e-5) +>>>>>>> TFlite e2e FP32 Object detection model + + ####################################################################### # MediaPipe # ------------- @@ -2038,6 +2061,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_mobilenet_v3() test_forward_inception_v3_net() test_forward_inception_v4_net() + test_forward_coco_ssd_mobilenet_v1() test_forward_mediapipe_hand_landmark() # End to End quantized From f4ad2a5086aaf72ce5e0fd68d7f44070cfed10ba Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Wed, 29 Apr 2020 19:22:38 +0000 Subject: [PATCH 2/8] Fix test --- tests/python/frontend/tflite/test_forward.py | 65 +++++++++++++------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index a67cd9b7c51d..4c863d44dc0c 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1741,23 +1741,27 @@ def test_detection_postprocess(): tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions]) tvm_output = run_tvm_graph(tflite_model, [box_encodings, class_predictions], ["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4) - # check valid count is the same + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same assert tvm_output[3] == tflite_output[3] - # check all the output shapes are the same - assert tvm_output[0].shape == tflite_output[0].shape - assert tvm_output[1].shape == tflite_output[1].shape - assert tvm_output[2].shape == tflite_output[2].shape valid_count = tvm_output[3][0] - # only check the valid detections are the same - # tvm has a different convention to tflite for invalid detections, it uses all -1s whereas - # tflite appears to put in nonsense data instead - tvm_boxes = 
tvm_output[0][0][:valid_count] - tvm_classes = tvm_output[1][0][:valid_count] - tvm_scores = tvm_output[2][0][:valid_count] - # check the output data is correct - tvm.testing.assert_allclose(np.squeeze(tvm_boxes), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) - tvm.testing.assert_allclose(np.squeeze(tvm_classes), np.squeeze(tflite_output[1]), rtol=1e-5, atol=1e-5) - tvm.testing.assert_allclose(np.squeeze(tvm_scores), np.squeeze(tflite_output[2]), rtol=1e-5, atol=1e-5) + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. + for i in range(0, valid_count): + # Check bounding box co-ords + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), + rtol=1e-5, atol=1e-5) + # Check the class + tvm.testing.assert_allclose(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]), + rtol=1e-5, atol=1e-5) + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) ####################################################################### @@ -1941,7 +1945,6 @@ def test_forward_qnn_mobilenet_v3_net(): tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) - ####################################################################### # SSD Mobilenet # ------------- @@ -1954,15 +1957,33 @@ def test_forward_coco_ssd_mobilenet_v1(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() - + np.random.seed(0) data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32') tflite_output = run_tflite_graph(tflite_model_buf, data) - tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=2) - for i in range(2): - tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]), - rtol=1e-5, atol=2e-5) ->>>>>>> TFlite e2e FP32 Object detection model + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4) + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same + assert tvm_output[3] == tflite_output[3] + valid_count = tvm_output[3][0] + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. 
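+    # (TVM pads the slots beyond valid_count with -1s, whereas TFLite leaves
+    # uninitialized data there, so an element-wise comparison over the full
+    # tensors would fail spuriously.)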
+ for i in range(0, valid_count): + # Check bounding box co-ords + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), + rtol=1e-5, atol=1e-5) + # Check the class + tvm.testing.assert_allclose(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]), + rtol=1e-5, atol=1e-5) + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) +>>>>>>> Fix test ####################################################################### From 96841697f320e087342a05cde8409c8aeeb8844a Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Thu, 30 Apr 2020 00:45:06 +0000 Subject: [PATCH 3/8] [Relay-TFLite] Quantized activations --- python/tvm/relay/frontend/tflite.py | 165 +++++++++++++------ python/tvm/relay/testing/tf.py | 5 + tests/python/frontend/tflite/test_forward.py | 45 ++++- 3 files changed, 162 insertions(+), 53 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 703ef9c8b6b0..517ec37744f8 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -320,6 +320,40 @@ def dequantize(self, expr, tensor): input_zero_point=tensor.qnn_params['zero_point']) return dequantized + + def convert_qnn_fused_activation_function(self, expr, fused_activation_fn, + scale, zero_point, dtype): + """Convert TFLite fused activation function. The expr is an input quantized tensor with + scale and zero point """ + try: + from tflite.ActivationFunctionType import ActivationFunctionType + except ImportError: + raise ImportError("The tflite package must be installed") + + # Quantize a float value to an integer + quantize = lambda value : (value / scale) + zero_point + + # The input expr is a quantized tensor with its scale and zero point. We calculate the + # suitable clip off points based on these scale and zero point. 
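+        # For example, with the (hypothetical) qparams scale=0.1 and
+        # zero_point=5, a fused RELU6 clips to [quantize(0), quantize(6)],
+        # i.e. [5, 65] in the quantized domain.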
+ if fused_activation_fn == ActivationFunctionType.NONE: + return expr + elif fused_activation_fn == ActivationFunctionType.RELU6: + return _op.clip(expr, + a_min=quantize(0), + a_max=quantize(6)) + elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: + return _op.clip(expr, + a_min=quantize(-1), + a_max=quantize(1)) + elif fused_activation_fn == ActivationFunctionType.RELU: + return _op.clip(expr, + a_min=quantize(0), + a_max=float(tvm.tir.op.min_value(dtype).value)) + + fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] + raise tvm.error.OpNotImplemented( + 'Quantized activation {} is not supported for frontend TFLite.'.format(fused_activation_fn_str)) + def convert_conv2d(self, op): """Convert TFLite conv2d""" return self.convert_conv(op, "conv2d") @@ -453,17 +487,16 @@ def convert_l2_normalization(self, op): if self.is_quantized(op): raise tvm.error.OpNotImplemented( 'TFLite quantized L2_NORMALIZATION operator is not supported yet.') + # TFL uses only the default epsilon value out = _op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1]) # if we have fused activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'TFLite quantized L2_NORMALIZATION operator\ - with fused activation function is not supported yet.') + if output_tensor.qnn_params: + raise tvm.error.OpNotImplemented( + 'TFLite quantized L2_NORMALIZATION operator is not supported yet.') + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) return out @@ -640,14 +673,20 @@ def convert_concatenation(self, op): output_zero_point=output_tensor.qnn_params['zero_point'], axis=concatenation_axis) - # if we have activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' 
- .format('qnn.op.concatenate')) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def _convert_unary_elemwise(self, relay_op, op): @@ -855,13 +894,20 @@ def _convert_elemwise(self, relay_op, op): op_options = op.BuiltinOptions() options.Init(op_options.Bytes, op_options.Pos) fused_activation_fn = options.FusedActivationFunction() - # if we have activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if output_tensor.qnn_params: - raise tvm.error.OpNotImplemented( - 'Elemwise operators with fused activation are not supported yet.') - out = self.convert_fused_activation_function(out, fused_activation_fn) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) return out def convert_add(self, op): @@ -1376,15 +1422,6 @@ def convert_fully_connected(self, op): dtype=bias_tensor_type_str) out = _op.nn.bias_add(out, bias_expr) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.dense')) - # Finally if the dense is quantized. Add a requantize at the end. 
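+        # The int32 accumulator of the quantized dense has an effective scale of
+        # data_scale * weight_scale and a zero point of 0; the requantize below
+        # converts from these qparams to the output tensor's qparams.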
if output_tensor.qnn_params: data_scale = input_tensor.qnn_params['scale'] @@ -1394,12 +1431,24 @@ def convert_fully_connected(self, op): new_input_scale_val = data_scale_val * weight_scale_val new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') + + # Call activation function + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=new_input_scale_val, + zero_point=0, + dtype='int32') + + # Requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, input_zero_point=new_input_zero_point, output_scale=output_tensor.qnn_params['scale'], output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) return out @@ -1435,18 +1484,20 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn): from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") - assert fused_activation_fn != ActivationFunctionType.NONE - if fused_activation_fn == ActivationFunctionType.RELU6: + + if fused_activation_fn == ActivationFunctionType.NONE: + return in_expr + elif fused_activation_fn == ActivationFunctionType.RELU6: return _op.clip(in_expr, a_min=0, a_max=6) - if fused_activation_fn == ActivationFunctionType.RELU: + elif fused_activation_fn == ActivationFunctionType.RELU: return _op.nn.relu(in_expr) - if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: + elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: return _op.clip(in_expr, a_min=-1, a_max=1) - if fused_activation_fn == ActivationFunctionType.TANH: + elif fused_activation_fn == ActivationFunctionType.TANH: return _op.tanh(in_expr) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( - 'Operator {} is not supported for frontend TFLite.'.format(fused_activation_fn_str)) + 'Fused activation {} is not supported for frontend TFLite.'.format(fused_activation_fn_str)) def convert_conv(self, op, conv_type): """convolution implementation.""" @@ -1583,17 +1634,9 @@ def convert_conv(self, op, conv_type): channel_axis = 3 out = _op.nn.bias_add(out, bias_expr, axis=channel_axis) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.conv2d')) - - # Finally if the conv is quantized. Add a requantize at the end. + # Handle fused activation. if output_tensor.qnn_params: + # Calculate the intermediate scale and zero point of the int32 output. 
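+            # E.g. (hypothetical values) data_scale=0.5 and weight_scale=0.02
+            # give an intermediate scale of 0.5 * 0.02 = 0.01, with a zero
+            # point of 0.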
data_scale = input_tensor.qnn_params['scale'] weight_scale = weight_tensor.qnn_params['scale'] data_scale_val = get_scalar_from_constant(data_scale) @@ -1601,12 +1644,24 @@ def convert_conv(self, op, conv_type): new_input_scale_val = data_scale_val * weight_scale_val new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') + + # Call activation function + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=new_input_scale_val, + zero_point=0, + dtype='int32') + + # Finally requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, input_zero_point=new_input_zero_point, output_scale=output_tensor.qnn_params['scale'], output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) return out @@ -1846,13 +1901,19 @@ def convert_pool2d(self, op, pool_type): raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool')) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if input_tensor.qnn_params: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.pool2d')) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def convert_pad(self, op): diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index 1a231eb1aaed..8cbab38f1994 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -184,10 +184,15 @@ def get_workload_official(model_url, model_sub_path): dir_path = os.path.dirname(model_path) import tarfile + import zipfile if model_path.endswith("tgz") or model_path.endswith("gz"): tar = tarfile.open(model_path) tar.extractall(path=dir_path) tar.close() + elif model_path.endswith("zip"): + zip_object = zipfile.ZipFile(model_path) + zip_object.extractall(path=dir_path) + zip_object.close() else: raise RuntimeError('Could not decompress the file: ' + model_path) return os.path.join(dir_path, model_sub_path) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 4c863d44dc0c..16bc8f5fbe05 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1945,6 +1945,49 @@ def test_forward_qnn_mobilenet_v3_net(): tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + +####################################################################### +# SSD Mobilenet +# ------------- + +def test_forward_qnn_coco_ssd_mobilenet_v1(): + """Test the quantized Coco SSD Mobilenet V1 TF Lite model.""" + pytest.skip("Unsupported op - use_regular_nms") + tflite_model_file = tf_testing.get_workload_official( + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip", + "detect.tflite") + + with open(tflite_model_file, "rb") 
as f: + tflite_model_buf = f.read() + + np.random.seed(0) + data = np.random.uniform(size=(1, 300, 300, 3)).astype('uint8') + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4) + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same + assert tvm_output[3] == tflite_output[3] + valid_count = tvm_output[3][0] + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. + for i in range(0, valid_count): + # Check bounding box co-ords + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), + rtol=1e-5, atol=1e-5) + # Check the class + tvm.testing.assert_allclose(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]), + rtol=1e-5, atol=1e-5) + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) + + + ####################################################################### # SSD Mobilenet # ------------- @@ -1957,7 +2000,7 @@ def test_forward_coco_ssd_mobilenet_v1(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() - + np.random.seed(0) data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32') tflite_output = run_tflite_graph(tflite_model_buf, data) From 22ac6897a2e23f41479afa4e1e154d22584467bb Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Sat, 2 May 2020 00:42:12 +0000 Subject: [PATCH 4/8] Flexbuffer parsing --- python/tvm/relay/frontend/tflite.py | 175 +++++------------- .../tvm/relay/frontend/tflite_flexbuffer.py | 154 +++++++++++++++ tests/python/frontend/tflite/test_forward.py | 40 ++-- 3 files changed, 225 insertions(+), 144 deletions(-) create mode 100644 python/tvm/relay/frontend/tflite_flexbuffer.py diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 517ec37744f8..34a0a4971ed0 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -31,6 +31,7 @@ from ... import nd as _nd from .common import ExprTable from .common import infer_shape as _infer_shape +from .tflite_flexbuffer import FlexBufferDecode __all__ = ['from_tflite'] @@ -330,8 +331,13 @@ def convert_qnn_fused_activation_function(self, expr, fused_activation_fn, except ImportError: raise ImportError("The tflite package must be installed") - # Quantize a float value to an integer - quantize = lambda value : (value / scale) + zero_point + # Quantize a float value to an quantized integer value + quantize = lambda x: float(int(round(x / scale)) + zero_point) + + # Get min/max of the output dtype. This will be used to ensure that clip a_min/a_max are not + # beyond the dtype range. + qmin = float(tvm.tir.op.min_value(dtype).value) + qmax = float(tvm.tir.op.max_value(dtype).value) # The input expr is a quantized tensor with its scale and zero point. We calculate the # suitable clip off points based on these scale and zero point. 
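+        # E.g. (hypothetical values) for int8 with scale=0.01 and zero_point=0,
+        # RELU6 would otherwise ask for a_max = quantize(6.0) = 600, which
+        # min(qmax, ...) clamps back to the representable 127.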
@@ -339,16 +345,16 @@ def convert_qnn_fused_activation_function(self, expr, fused_activation_fn, return expr elif fused_activation_fn == ActivationFunctionType.RELU6: return _op.clip(expr, - a_min=quantize(0), - a_max=quantize(6)) + a_min=max(qmin, quantize(0)), + a_max=min(qmax, quantize(6.0))) elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: return _op.clip(expr, - a_min=quantize(-1), - a_max=quantize(1)) + a_min=max(qmin, quantize(-1.0)), + a_max=min(qmax, quantize(1.0))) elif fused_activation_fn == ActivationFunctionType.RELU: return _op.clip(expr, - a_min=quantize(0), - a_max=float(tvm.tir.op.min_value(dtype).value)) + a_min=max(qmin, quantize(0.0)), + a_max=qmax) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( @@ -1432,14 +1438,6 @@ def convert_fully_connected(self, op): new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') - # Call activation function - out = self.convert_qnn_fused_activation_function(\ - expr=out, - fused_activation_fn=fused_activation_fn, - scale=new_input_scale_val, - zero_point=0, - dtype='int32') - # Requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, @@ -1447,6 +1445,17 @@ def convert_fully_connected(self, op): output_scale=output_tensor.qnn_params['scale'], output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + + # Call activation function + output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=output_scale_val, + zero_point=output_zero_point_val, + dtype=output_tensor_type_str) + else: out = self.convert_fused_activation_function(out, fused_activation_fn) @@ -1645,14 +1654,6 @@ def convert_conv(self, op, conv_type): new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') - # Call activation function - out = self.convert_qnn_fused_activation_function(\ - expr=out, - fused_activation_fn=fused_activation_fn, - scale=new_input_scale_val, - zero_point=0, - dtype='int32') - # Finally requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, @@ -1660,6 +1661,16 @@ def convert_conv(self, op, conv_type): output_scale=output_tensor.qnn_params['scale'], output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + + # Call activation function + output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=output_scale_val, + zero_point=output_zero_point_val, + dtype=output_tensor_type_str) else: out = self.convert_fused_activation_function(out, fused_activation_fn) @@ -2302,28 +2313,15 @@ def convert_transpose_conv(self, op): def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" - _option_names = [ - "w_scale", - "max_detections", - "_output_quantized", - "detections_per_class", - "x_scale", - "nms_score_threshold", - "num_classes", - "max_classes_per_detection", - "use_regular_nms", - "y_scale", - "h_scale", - "_support_output_type_float_in_quantized_op", - "nms_iou_threshold" - ] - - custom_options = 
get_custom_options(op, _option_names) - if custom_options["use_regular_nms"]: - raise tvm.error.OpAttributeUnImplemented( - "use_regular_nms=True is not yet supported for operator {}." - .format("TFLite_Detection_PostProcess") - ) + flexbuffer = op.CustomOptionsAsNumpy().tobytes() + custom_options = FlexBufferDecode(flexbuffer).decode() + + if "use_regular_nms" in custom_options: + if custom_options["use_regular_nms"]: + raise tvm.error.OpAttributeUnImplemented( + "use_regular_nms=True is not yet supported for operator {}." + .format("TFLite_Detection_PostProcess") + ) inputs = self.get_input_tensors(op) assert len(inputs) == 3, "inputs length should be 3" @@ -2494,91 +2492,6 @@ def get_tensor_name(subgraph, tensor_idx): return subgraph.Tensors(tensor_idx).Name().decode("utf-8") -def get_custom_options(op, option_names): - """Get the options of a custom operator. - - This implements partial flexbuffer deserialization to be able - to read custom options. It is not intended to be a general - purpose flexbuffer deserializer and as such only supports a - limited number of types and assumes the data is a flat map. - - Parameters - ---------- - op: - A custom TFlite operator. - option_names: list - A complete list of the custom option names. - - Returns - ------- - options: dict - A dictionary of the custom options. - - """ - import struct - from enum import IntEnum - - class _FlexBufferType(IntEnum): - """Flexbuffer type schema from flexbuffers.h""" - FBT_NULL = 0 - FBT_INT = 1 - FBT_UINT = 2 - FBT_FLOAT = 3 - # Types above stored inline, types below store an offset. - FBT_KEY = 4 - FBT_STRING = 5 - FBT_INDIRECT_INT = 6 - FBT_INDIRECT_UINT = 7 - FBT_INDIRECT_FLOAT = 8 - FBT_MAP = 9 - FBT_VECTOR = 10 # Untyped. - FBT_VECTOR_INT = 11 # Typed any size (stores no type table). - FBT_VECTOR_UINT = 12 - FBT_VECTOR_FLOAT = 13 - FBT_VECTOR_KEY = 14 - FBT_VECTOR_STRING = 15 - FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field). - FBT_VECTOR_UINT2 = 17 - FBT_VECTOR_FLOAT2 = 18 - FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field). - FBT_VECTOR_UINT3 = 20 - FBT_VECTOR_FLOAT3 = 21 - FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field). 
- FBT_VECTOR_UINT4 = 23 - FBT_VECTOR_FLOAT4 = 24 - FBT_BLOB = 25 - FBT_BOOL = 26 - FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type - - buffer = op.CustomOptionsAsNumpy().tobytes() - value_vector_offset = buffer[-3] - buffer = buffer[:-3] - num_bytes = 4 # Assume all values are stored in 32 bit width - value_vector_size = struct.unpack( - "> 2) - value_offset = -value_vector_offset + i*num_bytes - value_bytes = buffer[value_offset:value_offset+num_bytes] - if flex_type == _FlexBufferType.FBT_BOOL: - value = bool(value_bytes[0]) - if flex_type == _FlexBufferType.FBT_INT: - value = struct.unpack("> 2) + value_bytes = self.buffer[end + i * byte_width: end + (i + 1) * byte_width] + if value_type == FlexBufferType.FBT_BOOL: + value = bool(value_bytes[0]) + elif value_type == FlexBufferType.FBT_INT: + value = struct.unpack("> 2); + byte_width = 1 << BitWidth(root_packed_type & 3); + + if root_type == FlexBufferType.FBT_MAP: + return self.decode_map(root_end, byte_width, root_byte_width) + raise NotImplementedError("Flexbuffer Decoding is partially imlpemented.") diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 16bc8f5fbe05..220b0664e59b 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -73,6 +73,16 @@ def get_real_image(im_height, im_width): data = np.reshape(x, (1, im_height, im_width, 3)) return data +def get_real_image_object_detection(im_height, im_width): + repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/' + img_name = 'street_small.jpg' + image_url = os.path.join(repo_base, img_name) + img_path = download_testdata(image_url, img_name, module='data') + image = Image.open(img_path).resize((im_height, im_width)) + x = np.array(image).astype('uint8') + data = np.reshape(x, (1, im_height, im_width, 3)) + return data + def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', out_names=None): """ Generic function to compile on relay and execute on tvm """ @@ -98,6 +108,7 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict) + with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) @@ -1952,7 +1963,10 @@ def test_forward_qnn_mobilenet_v3_net(): def test_forward_qnn_coco_ssd_mobilenet_v1(): """Test the quantized Coco SSD Mobilenet V1 TF Lite model.""" - pytest.skip("Unsupported op - use_regular_nms") + pytest.skip("LLVM bug - getExtendedVectorNumElements - " + + "https://discuss.tvm.ai/t/segfault-in-llvm/3567. 
The workaround is to use a " + + "specific target, for example, llvm -mpcu=core-avx2") + tflite_model_file = tf_testing.get_workload_official( "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip", "detect.tflite") @@ -1960,8 +1974,7 @@ def test_forward_qnn_coco_ssd_mobilenet_v1(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() - np.random.seed(0) - data = np.random.uniform(size=(1, 300, 300, 3)).astype('uint8') + data = get_real_image_object_detection(300, 300) tflite_output = run_tflite_graph(tflite_model_buf, data) tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4) @@ -1976,16 +1989,18 @@ def test_forward_qnn_coco_ssd_mobilenet_v1(): # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare # tflite and tvm tensors for only valid boxes. for i in range(0, valid_count): - # Check bounding box co-ords + # Check bounding box co-ords. The tolerances have to be adjusted because of differences between + # for requantiize operator in TFLite and TVM. tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), - rtol=1e-5, atol=1e-5) + rtol=1e-1, atol=1e-1) + # Check the class - tvm.testing.assert_allclose(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]), - rtol=1e-5, atol=1e-5) + # Stricter check to ensure class remains same + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i])) + # Check the score tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), - rtol=1e-5, atol=1e-5) - + rtol=1e-2, atol=1e-2) ####################################################################### @@ -2021,13 +2036,11 @@ def test_forward_coco_ssd_mobilenet_v1(): tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), rtol=1e-5, atol=1e-5) # Check the class - tvm.testing.assert_allclose(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]), - rtol=1e-5, atol=1e-5) + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i])) + # Check the score tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), rtol=1e-5, atol=1e-5) ->>>>>>> Fix test - ####################################################################### # MediaPipe @@ -2135,3 +2148,4 @@ def test_forward_mediapipe_hand_landmark(): #This also fails with a segmentation fault in my run #with Tflite 1.15.2 test_forward_qnn_mobilenet_v3_net() + test_forward_qnn_coco_ssd_mobilenet_v1() From 7e3e04d713d4c2262f10c9955d5757978012c91b Mon Sep 17 00:00:00 2001 From: anijain2305 Date: Sat, 2 May 2020 01:33:04 +0000 Subject: [PATCH 5/8] Lint --- python/tvm/relay/frontend/tflite.py | 31 +++++++------------ .../tvm/relay/frontend/tflite_flexbuffer.py | 10 +++--- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 34a0a4971ed0..4021060bab0c 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -31,7 +31,7 @@ from ... 
import nd as _nd from .common import ExprTable from .common import infer_shape as _infer_shape -from .tflite_flexbuffer import FlexBufferDecode +from .tflite_flexbuffer import FlexBufferDecoder __all__ = ['from_tflite'] @@ -343,22 +343,22 @@ def convert_qnn_fused_activation_function(self, expr, fused_activation_fn, # suitable clip off points based on these scale and zero point. if fused_activation_fn == ActivationFunctionType.NONE: return expr - elif fused_activation_fn == ActivationFunctionType.RELU6: + if fused_activation_fn == ActivationFunctionType.RELU6: return _op.clip(expr, a_min=max(qmin, quantize(0)), a_max=min(qmax, quantize(6.0))) - elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: + if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: return _op.clip(expr, a_min=max(qmin, quantize(-1.0)), a_max=min(qmax, quantize(1.0))) - elif fused_activation_fn == ActivationFunctionType.RELU: + if fused_activation_fn == ActivationFunctionType.RELU: return _op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( - 'Quantized activation {} is not supported for frontend TFLite.'.format(fused_activation_fn_str)) + 'Quantized activation {} is not supported yet.'.format(fused_activation_fn_str)) def convert_conv2d(self, op): """Convert TFLite conv2d""" @@ -468,7 +468,6 @@ def convert_l2_normalization(self, op): try: from tflite.BuiltinOptions import BuiltinOptions from tflite.L2NormOptions import L2NormOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -501,8 +500,7 @@ def convert_l2_normalization(self, op): if output_tensor.qnn_params: raise tvm.error.OpNotImplemented( 'TFLite quantized L2_NORMALIZATION operator is not supported yet.') - else: - out = self.convert_fused_activation_function(out, fused_activation_fn) + out = self.convert_fused_activation_function(out, fused_activation_fn) return out @@ -647,7 +645,6 @@ def convert_concatenation(self, op): try: from tflite.ConcatenationOptions import ConcatenationOptions from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -835,7 +832,6 @@ def _convert_elemwise(self, relay_op, op): from tflite.MulOptions import MulOptions from tflite.DivOptions import DivOptions from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -1361,7 +1357,6 @@ def convert_fully_connected(self, op): from tflite.FullyConnectedOptions import FullyConnectedOptions from tflite.BuiltinOptions import BuiltinOptions from tflite.TensorType import TensorType - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") @@ -1496,23 +1491,22 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn): if fused_activation_fn == ActivationFunctionType.NONE: return in_expr - elif fused_activation_fn == ActivationFunctionType.RELU6: + if fused_activation_fn == ActivationFunctionType.RELU6: return _op.clip(in_expr, a_min=0, a_max=6) - elif fused_activation_fn == ActivationFunctionType.RELU: + if fused_activation_fn == ActivationFunctionType.RELU: return _op.nn.relu(in_expr) - elif 
fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: + if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: return _op.clip(in_expr, a_min=-1, a_max=1) - elif fused_activation_fn == ActivationFunctionType.TANH: + if fused_activation_fn == ActivationFunctionType.TANH: return _op.tanh(in_expr) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( - 'Fused activation {} is not supported for frontend TFLite.'.format(fused_activation_fn_str)) + 'Fused activation {} is not supported yet.'.format(fused_activation_fn_str)) def convert_conv(self, op, conv_type): """convolution implementation.""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType from tflite.TensorType import TensorType from tflite.Conv2DOptions import Conv2DOptions from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions @@ -1837,7 +1831,6 @@ def convert_pool2d(self, op, pool_type): """pool2d implementation.""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType from tflite.Pool2DOptions import Pool2DOptions from tflite.Padding import Padding except ImportError: @@ -2314,7 +2307,7 @@ def convert_transpose_conv(self, op): def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" flexbuffer = op.CustomOptionsAsNumpy().tobytes() - custom_options = FlexBufferDecode(flexbuffer).decode() + custom_options = FlexBufferDecoder(flexbuffer).decode() if "use_regular_nms" in custom_options: if custom_options["use_regular_nms"]: diff --git a/python/tvm/relay/frontend/tflite_flexbuffer.py b/python/tvm/relay/frontend/tflite_flexbuffer.py index 6b8606af8889..e3427ab76e51 100644 --- a/python/tvm/relay/frontend/tflite_flexbuffer.py +++ b/python/tvm/relay/frontend/tflite_flexbuffer.py @@ -60,7 +60,7 @@ class FlexBufferType(IntEnum): FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type -class FlexBufferDecode(object): +class FlexBufferDecoder(object): """ This implements partial flexbuffer deserialization to be able to read custom options. It is not intended to be a general @@ -129,9 +129,6 @@ def decode_map(self, end, byte_width, parent_byte_width): # Find keys keys_offset = mid_loc - byte_width * 3 keys_end = self.indirect_jump(keys_offset, byte_width) - keys_byte_width = struct.unpack(\ - "> 2); - byte_width = 1 << BitWidth(root_packed_type & 3); + root_type = FlexBufferType(root_packed_type >> 2) + byte_width = 1 << BitWidth(root_packed_type & 3) if root_type == FlexBufferType.FBT_MAP: return self.decode_map(root_end, byte_width, root_byte_width) From a52bf0cda27205d9e43312414cdf5a56e0f298ad Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 2 May 2020 08:40:46 +0000 Subject: [PATCH 6/8] Relaxing checks. --- tests/python/frontend/tflite/test_forward.py | 35 ++++++++++++-------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 220b0664e59b..ad610e7f94d9 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1989,18 +1989,27 @@ def test_forward_qnn_coco_ssd_mobilenet_v1(): # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare # tflite and tvm tensors for only valid boxes. for i in range(0, valid_count): - # Check bounding box co-ords. 
The tolerances have to be adjusted because of differences between - # for requantiize operator in TFLite and TVM. - tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), - rtol=1e-1, atol=1e-1) - - # Check the class - # Stricter check to ensure class remains same - np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i])) - - # Check the score - tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), - rtol=1e-2, atol=1e-2) + # We compare the bounding boxes whose prediction score is above 60%. This is typical in end + # to end application where a low prediction score is discarded. This is also needed because + # multiple low score bounding boxes can have same score and TFlite and TVM can have + # different orderings for same score bounding boxes. Another reason for minor differences in + # low score bounding boxes is the difference between TVM and TFLite for requantize operator. + if tvm_output[2][0][i] > 0.6: + # Check bounding box co-ords. The tolerances have to be adjusted, from 1e-5 to 1e-2, + # because of differences between for requantiize operator in TFLite and TVM. + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), + np.squeeze(tflite_output[0][0][i]), + rtol=1e-2, atol=1e-2) + + # Check the class + # Stricter check to ensure class remains same + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), + np.squeeze(tflite_output[1][0][i])) + + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), + np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) ####################################################################### @@ -2008,7 +2017,7 @@ def test_forward_qnn_coco_ssd_mobilenet_v1(): # ------------- def test_forward_coco_ssd_mobilenet_v1(): - """Test the quantized Coco SSD Mobilenet V1 TF Lite model.""" + """Test the FP32 Coco SSD Mobilenet V1 TF Lite model.""" tflite_model_file = tf_testing.get_workload_official( "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tgz", "ssd_mobilenet_v1_coco_2018_01_28.tflite") From c516368ee60d8a04f161281e194d081108291bb0 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 4 May 2020 17:06:28 +0000 Subject: [PATCH 7/8] Github reviews --- python/tvm/relay/frontend/tflite_flexbuffer.py | 2 +- tests/python/frontend/tflite/test_forward.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tflite_flexbuffer.py b/python/tvm/relay/frontend/tflite_flexbuffer.py index e3427ab76e51..d08570be2855 100644 --- a/python/tvm/relay/frontend/tflite_flexbuffer.py +++ b/python/tvm/relay/frontend/tflite_flexbuffer.py @@ -137,7 +137,7 @@ def decode_map(self, end, byte_width, parent_byte_width): return dict(zip(keys, values)) def decode(self): - """ Decode the buffer. Decoding is paritally implemented """ + """ Decode the buffer. 
Decoding is partially implemented """ root_end = len(self.buffer) - 1 root_byte_width = self.buffer[root_end] root_end -= 1 diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index ad610e7f94d9..ee4921262047 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1767,9 +1767,12 @@ def test_detection_postprocess(): # Check bounding box co-ords tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), rtol=1e-5, atol=1e-5) + # Check the class - tvm.testing.assert_allclose(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]), - rtol=1e-5, atol=1e-5) + # Stricter check to ensure class remains same + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), + np.squeeze(tflite_output[1][0][i])) + # Check the score tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), rtol=1e-5, atol=1e-5) @@ -1958,7 +1961,7 @@ def test_forward_qnn_mobilenet_v3_net(): ####################################################################### -# SSD Mobilenet +# Quantized SSD Mobilenet # ------------- def test_forward_qnn_coco_ssd_mobilenet_v1(): From 6aa7904a6a0f48fdb482478a11f24a90340f0085 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 6 May 2020 16:43:15 +0000 Subject: [PATCH 8/8] comments --- python/tvm/relay/testing/tf.py | 4 ++-- tests/python/frontend/tflite/test_forward.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index 8cbab38f1994..dc7937c0b346 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -183,13 +183,13 @@ def get_workload_official(model_url, model_sub_path): model_path = download_testdata(model_url, model_tar_name, module=['tf', 'official']) dir_path = os.path.dirname(model_path) - import tarfile - import zipfile if model_path.endswith("tgz") or model_path.endswith("gz"): + import tarfile tar = tarfile.open(model_path) tar.extractall(path=dir_path) tar.close() elif model_path.endswith("zip"): + import zipfile zip_object = zipfile.ZipFile(model_path) zip_object.extractall(path=dir_path) zip_object.close() diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index ee4921262047..42430fe2b6a4 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1962,7 +1962,7 @@ def test_forward_qnn_mobilenet_v3_net(): ####################################################################### # Quantized SSD Mobilenet -# ------------- +# ----------------------- def test_forward_qnn_coco_ssd_mobilenet_v1(): """Test the quantized Coco SSD Mobilenet V1 TF Lite model."""
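
For reference, below is a minimal standalone sketch of the requantize arithmetic
that the quantized conv/dense conversions in patch 3 rely on. All numeric values
are hypothetical, and np.round stands in conceptually for the fixed-point
rounding that qnn.op.requantize actually performs.

import numpy as np

data_scale, weight_scale = 0.5, 0.02       # hypothetical input/weight scales
output_scale, output_zero_point = 0.1, 3   # hypothetical output qparams

# Fake int32 accumulator values, as produced by a quantized conv2d/dense.
acc = np.array([120, -40, 700], dtype=np.int32)

# The accumulator's effective scale is data_scale * weight_scale, zero point 0.
real = acc.astype(np.float64) * (data_scale * weight_scale)

# Requantize into the output tensor's (scale, zero_point) domain and clamp to
# the uint8 range.
q = np.clip(np.round(real / output_scale) + output_zero_point, 0, 255)
print(q.astype(np.uint8))   # [15  0 73]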