From bd5a180995797a41cbe9eaa5acabbf27f2905334 Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Wed, 3 Feb 2021 18:12:21 +0900 Subject: [PATCH 1/9] [FRONTEND][TFLITE] get input tensor information from graph --- python/tvm/relay/frontend/tflite.py | 63 +++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index f474e59407e0..aa5c12ffed91 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3511,7 +3511,60 @@ def get_tensor_name(subgraph, tensor_idx): return subgraph.Tensors(tensor_idx).Name().decode("utf-8") -def from_tflite(model, shape_dict, dtype_dict): +def get_tensor_shape(subgraph, tensor_idx): + """Get the tensor shape. + + Parameters + ---------- + subgraph: + tflite.Subgraph.Subgraph + + tensor: + tensor index in subgraph + + Returns + ------- + tensor shape + """ + return tuple(subgraph.Tensors(tensor_idx).ShapeAsNumpy()) + + +def get_tensor_type(subgraph, tensor_idx): + """Get the tensor type. + + Parameters + ---------- + subgraph: + tflite.Subgraph.Subgraph + + tensor: + tensor index in subgraph + + Returns + ------- + tensor type + """ + from enum import Enum + + class TensorType(Enum): + FLOAT32 = 0 + FLOAT16 = 1 + INT32 = 2 + UINT8 = 3 + INT64 = 4 + STRING = 5 + BOOL = 6 + INT16 = 7 + COMPLEX64 = 8 + INT8 = 9 + FLOAT64 = 10 + COMPLEX128 = 11 + UINT64 = 12 + + return TensorType(subgraph.Tensors(tensor_idx).Type()).name.lower() + + +def from_tflite(model, shape_dict=None, dtype_dict=None): """Convert from tflite model into compatible relay Function. Parameters @@ -3560,8 +3613,12 @@ def from_tflite(model, shape_dict, dtype_dict): exp_tab = ExprTable() for model_input in model_inputs: model_input_name = get_tensor_name(subgraph, model_input) - shape = shape_dict[model_input_name] if model_input_name in shape_dict else None - dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32" + try: + shape = get_tensor_shape(subgraph, model_input) + dtype = get_tensor_type(subgraph, model_input) + except: + shape = shape_dict[model_input_name] if model_input_name in shape_dict else None + dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32" exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype)) # op code in model From d4532d7b8817c62d1533f89bf5186797c62fdc62 Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Wed, 3 Feb 2021 18:20:57 +0900 Subject: [PATCH 2/9] remove bare-except --- python/tvm/relay/frontend/tflite.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index aa5c12ffed91..de70724ef1ff 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3613,12 +3613,14 @@ def from_tflite(model, shape_dict=None, dtype_dict=None): exp_tab = ExprTable() for model_input in model_inputs: model_input_name = get_tensor_name(subgraph, model_input) - try: - shape = get_tensor_shape(subgraph, model_input) - dtype = get_tensor_type(subgraph, model_input) - except: + if shape_dict: shape = shape_dict[model_input_name] if model_input_name in shape_dict else None + else: + shape = get_tensor_shape(subgraph, model_input) + if dtype_dict: dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32" + else: + dtype = get_tensor_type(subgraph, model_input) exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype)) # op code in model From d58353ba5d5dd338a4fd41b204ce9264b0e134a2 Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Wed, 3 Feb 2021 19:04:03 +0900 Subject: [PATCH 3/9] fix lint --- python/tvm/relay/frontend/tflite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index de70724ef1ff..983dfa911032 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3547,6 +3547,7 @@ def get_tensor_type(subgraph, tensor_idx): from enum import Enum class TensorType(Enum): + """ Enum defined in tensorflow lite """ FLOAT32 = 0 FLOAT16 = 1 INT32 = 2 From e4621eca6c6c18a4b91c8dfa4fa94f3591ea5981 Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Wed, 3 Feb 2021 19:25:44 +0900 Subject: [PATCH 4/9] delete empty line --- python/tvm/relay/frontend/tflite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 983dfa911032..657f29e9703f 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3548,6 +3548,7 @@ def get_tensor_type(subgraph, tensor_idx): class TensorType(Enum): """ Enum defined in tensorflow lite """ + FLOAT32 = 0 FLOAT16 = 1 INT32 = 2 From 86c70b79df437cb7b269f3821e187eafc0e648fa Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Thu, 4 Feb 2021 14:25:00 +0900 Subject: [PATCH 5/9] comment change --- python/tvm/relay/frontend/tflite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 657f29e9703f..ebbb39c20238 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3542,7 +3542,7 @@ def get_tensor_type(subgraph, tensor_idx): Returns ------- - tensor type + tensor type in string """ from enum import Enum From c954a95d57cb6197d7d693538d6a2c04ee17f51e Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Wed, 10 Feb 2021 20:23:36 +0900 Subject: [PATCH 6/9] move some of the tflite frontend code from tvmc to tflite.py --- python/tvm/driver/tvmc/frontends.py | 33 +-------- python/tvm/relay/frontend/tflite.py | 102 +++++++++++----------------- 2 files changed, 42 insertions(+), 93 deletions(-) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index bb54b82cceca..62249e2993c9 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -237,41 +237,10 @@ def load(self, path): if version != 3: raise TVMCException("input file not tflite version 3") - logger.debug("tflite_input_type") - shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model) - logger.debug("parse TFLite model and convert into Relay computation graph") - mod, params = relay.frontend.from_tflite( - tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict - ) + mod, params = relay.frontend.from_tflite(tflite_model) return mod, params - @staticmethod - def _decode_type(n): - return TFLiteFrontend._tflite_m[n] - - @staticmethod - def _input_type(model): - subgraph_count = model.SubgraphsLength() - assert subgraph_count > 0 - shape_dict = {} - dtype_dict = {} - for subgraph_index in range(subgraph_count): - subgraph = model.Subgraphs(subgraph_index) - inputs_count = subgraph.InputsLength() - assert inputs_count >= 1 - for input_index in range(inputs_count): - input_ = subgraph.Inputs(input_index) - assert subgraph.TensorsLength() > input_ - tensor = subgraph.Tensors(input_) - input_shape = tuple(tensor.ShapeAsNumpy()) - tensor_type = tensor.Type() - input_name = tensor.Name().decode("utf8") - shape_dict[input_name] = input_shape - dtype_dict[input_name] = TFLiteFrontend._decode_type(tensor_type) - - return shape_dict, dtype_dict - class PyTorchFrontend(Frontend): """ PyTorch frontend for TVMC """ diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index ef2ddb01e580..eb9d4aa2158f 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3539,59 +3539,42 @@ def get_tensor_name(subgraph, tensor_idx): return subgraph.Tensors(tensor_idx).Name().decode("utf-8") -def get_tensor_shape(subgraph, tensor_idx): - """Get the tensor shape. - - Parameters - ---------- - subgraph: - tflite.Subgraph.Subgraph - - tensor: - tensor index in subgraph - - Returns - ------- - tensor shape - """ - return tuple(subgraph.Tensors(tensor_idx).ShapeAsNumpy()) - - -def get_tensor_type(subgraph, tensor_idx): - """Get the tensor type. - - Parameters - ---------- - subgraph: - tflite.Subgraph.Subgraph - - tensor: - tensor index in subgraph - - Returns - ------- - tensor type in string - """ - from enum import Enum - - class TensorType(Enum): - """ Enum defined in tensorflow lite """ - - FLOAT32 = 0 - FLOAT16 = 1 - INT32 = 2 - UINT8 = 3 - INT64 = 4 - STRING = 5 - BOOL = 6 - INT16 = 7 - COMPLEX64 = 8 - INT8 = 9 - FLOAT64 = 10 - COMPLEX128 = 11 - UINT64 = 12 - - return TensorType(subgraph.Tensors(tensor_idx).Type()).name.lower() +def _decode_type(n): + _tflite_m = { + 0: "float32", + 1: "float16", + 2: "int32", + 3: "uint8", + 4: "int64", + 5: "string", + 6: "bool", + 7: "int16", + 8: "complex64", + 9: "int8", + } + return _tflite_m[n] + + +def _input_type(model): + subgraph_count = model.SubgraphsLength() + assert subgraph_count > 0 + shape_dict = {} + dtype_dict = {} + for subgraph_index in range(subgraph_count): + subgraph = model.Subgraphs(subgraph_index) + inputs_count = subgraph.InputsLength() + assert inputs_count >= 1 + for input_index in range(inputs_count): + input_ = subgraph.Inputs(input_index) + assert subgraph.TensorsLength() > input_ + tensor = subgraph.Tensors(input_) + input_shape = tuple(tensor.ShapeAsNumpy()) + tensor_type = tensor.Type() + input_name = tensor.Name().decode("utf8") + shape_dict[input_name] = input_shape + dtype_dict[input_name] = _decode_type(tensor_type) + + return shape_dict, dtype_dict def from_tflite(model, shape_dict=None, dtype_dict=None): @@ -3632,6 +3615,9 @@ def from_tflite(model, shape_dict=None, dtype_dict=None): assert isinstance(model, tflite.Model.Model) + if not shape_dict or not dtype_dict: + shape_dict, dtype_dict = _input_type(model) + # keep the same as tflite assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)" subgraph = model.Subgraphs(0) @@ -3643,14 +3629,8 @@ def from_tflite(model, shape_dict=None, dtype_dict=None): exp_tab = ExprTable() for model_input in model_inputs: model_input_name = get_tensor_name(subgraph, model_input) - if shape_dict: - shape = shape_dict[model_input_name] if model_input_name in shape_dict else None - else: - shape = get_tensor_shape(subgraph, model_input) - if dtype_dict: - dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32" - else: - dtype = get_tensor_type(subgraph, model_input) + shape = shape_dict[model_input_name] if model_input_name in shape_dict else None + dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32" exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype)) # op code in model From 53e044aa88e4972f5847aa628b9ca4e41e4db716 Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Wed, 10 Feb 2021 20:43:52 +0900 Subject: [PATCH 7/9] update shape and dtype when user provided them --- python/tvm/relay/frontend/tflite.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index eb9d4aa2158f..1b593ad8dea3 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3615,8 +3615,11 @@ def from_tflite(model, shape_dict=None, dtype_dict=None): assert isinstance(model, tflite.Model.Model) - if not shape_dict or not dtype_dict: - shape_dict, dtype_dict = _input_type(model) + _shape_dict, _dtype_dict = _input_type(model) + if shape_dict is not None: + _shape_dict.update(shape_dict) + if dtype_dict is not None: + _dtype_dict.update(dtype_dict) # keep the same as tflite assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)" @@ -3629,8 +3632,8 @@ def from_tflite(model, shape_dict=None, dtype_dict=None): exp_tab = ExprTable() for model_input in model_inputs: model_input_name = get_tensor_name(subgraph, model_input) - shape = shape_dict[model_input_name] if model_input_name in shape_dict else None - dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32" + shape = _shape_dict[model_input_name] if model_input_name in _shape_dict else None + dtype = _dtype_dict[model_input_name] if model_input_name in _dtype_dict else "float32" exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype)) # op code in model From 2ff9b4364ab5f3ce663e9d4727ed90fd636c1b75 Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Thu, 11 Feb 2021 00:44:55 +0900 Subject: [PATCH 8/9] remove unused var. pass user provided shape_dict --- python/tvm/driver/tvmc/frontends.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index caedf1f5d69f..4943269df9a6 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -198,19 +198,6 @@ def load(self, path, shape_dict=None): class TFLiteFrontend(Frontend): """ TFLite frontend for TVMC """ - _tflite_m = { - 0: "float32", - 1: "float16", - 2: "int32", - 3: "uint8", - 4: "int64", - 5: "string", - 6: "bool", - 7: "int16", - 8: "complex64", - 9: "int8", - } - @staticmethod def name(): return "tflite" @@ -243,6 +230,7 @@ def load(self, path, shape_dict=None): logger.debug("parse TFLite model and convert into Relay computation graph") mod, params = relay.frontend.from_tflite(tflite_model) + mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict) return mod, params From b80e798409122f0908809adac4dcb43999a1f900 Mon Sep 17 00:00:00 2001 From: EunTaik Lee Date: Thu, 11 Feb 2021 01:15:43 +0900 Subject: [PATCH 9/9] remove duplicate code --- python/tvm/driver/tvmc/frontends.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py index 4943269df9a6..16e6c8eb966e 100644 --- a/python/tvm/driver/tvmc/frontends.py +++ b/python/tvm/driver/tvmc/frontends.py @@ -229,7 +229,6 @@ def load(self, path, shape_dict=None): raise TVMCException("input file not tflite version 3") logger.debug("parse TFLite model and convert into Relay computation graph") - mod, params = relay.frontend.from_tflite(tflite_model) mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict) return mod, params