diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py
index 53fbed66c8fc..16e6c8eb966e 100644
--- a/python/tvm/driver/tvmc/frontends.py
+++ b/python/tvm/driver/tvmc/frontends.py
@@ -198,19 +198,6 @@ def load(self, path, shape_dict=None):
 class TFLiteFrontend(Frontend):
     """ TFLite frontend for TVMC """
 
-    _tflite_m = {
-        0: "float32",
-        1: "float16",
-        2: "int32",
-        3: "uint8",
-        4: "int64",
-        5: "string",
-        6: "bool",
-        7: "int16",
-        8: "complex64",
-        9: "int8",
-    }
-
     @staticmethod
     def name():
         return "tflite"
@@ -241,43 +228,10 @@ def load(self, path, shape_dict=None):
         if version != 3:
             raise TVMCException("input file not tflite version 3")
 
-        logger.debug("tflite_input_type")
-        input_shapes, dtype_dict = TFLiteFrontend._input_type(tflite_model)
-        if shape_dict is not None:
-            input_shapes.update(shape_dict)
-        logger.debug("parse TFLite model and convert into Relay computation graph")
-        mod, params = relay.frontend.from_tflite(
-            tflite_model, shape_dict=input_shapes, dtype_dict=dtype_dict
-        )
+        mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict)
 
         return mod, params
 
-    @staticmethod
-    def _decode_type(n):
-        return TFLiteFrontend._tflite_m[n]
-
-    @staticmethod
-    def _input_type(model):
-        subgraph_count = model.SubgraphsLength()
-        assert subgraph_count > 0
-        shape_dict = {}
-        dtype_dict = {}
-        for subgraph_index in range(subgraph_count):
-            subgraph = model.Subgraphs(subgraph_index)
-            inputs_count = subgraph.InputsLength()
-            assert inputs_count >= 1
-            for input_index in range(inputs_count):
-                input_ = subgraph.Inputs(input_index)
-                assert subgraph.TensorsLength() > input_
-                tensor = subgraph.Tensors(input_)
-                input_shape = tuple(tensor.ShapeAsNumpy())
-                tensor_type = tensor.Type()
-                input_name = tensor.Name().decode("utf8")
-                shape_dict[input_name] = input_shape
-                dtype_dict[input_name] = TFLiteFrontend._decode_type(tensor_type)
-
-        return shape_dict, dtype_dict
-
 
 class PyTorchFrontend(Frontend):
     """ PyTorch frontend for TVMC """
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 6d9bb18a7573..1b593ad8dea3 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -3539,7 +3539,45 @@ def get_tensor_name(subgraph, tensor_idx):
     return subgraph.Tensors(tensor_idx).Name().decode("utf-8")
 
 
-def from_tflite(model, shape_dict, dtype_dict):
+def _decode_type(n):
+    _tflite_m = {
+        0: "float32",
+        1: "float16",
+        2: "int32",
+        3: "uint8",
+        4: "int64",
+        5: "string",
+        6: "bool",
+        7: "int16",
+        8: "complex64",
+        9: "int8",
+    }
+    return _tflite_m[n]
+
+
+def _input_type(model):
+    subgraph_count = model.SubgraphsLength()
+    assert subgraph_count > 0
+    shape_dict = {}
+    dtype_dict = {}
+    for subgraph_index in range(subgraph_count):
+        subgraph = model.Subgraphs(subgraph_index)
+        inputs_count = subgraph.InputsLength()
+        assert inputs_count >= 1
+        for input_index in range(inputs_count):
+            input_ = subgraph.Inputs(input_index)
+            assert subgraph.TensorsLength() > input_
+            tensor = subgraph.Tensors(input_)
+            input_shape = tuple(tensor.ShapeAsNumpy())
+            tensor_type = tensor.Type()
+            input_name = tensor.Name().decode("utf8")
+            shape_dict[input_name] = input_shape
+            dtype_dict[input_name] = _decode_type(tensor_type)
+
+    return shape_dict, dtype_dict
+
+
+def from_tflite(model, shape_dict=None, dtype_dict=None):
     """Convert from tflite model into compatible relay Function.
 
     Parameters
@@ -3577,6 +3615,12 @@ def from_tflite(model, shape_dict, dtype_dict):
 
         assert isinstance(model, tflite.Model.Model)
 
+    _shape_dict, _dtype_dict = _input_type(model)
+    if shape_dict is not None:
+        _shape_dict.update(shape_dict)
+    if dtype_dict is not None:
+        _dtype_dict.update(dtype_dict)
+
     # keep the same as tflite
     assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)"
     subgraph = model.Subgraphs(0)
@@ -3588,8 +3632,8 @@ def from_tflite(model, shape_dict, dtype_dict):
     exp_tab = ExprTable()
     for model_input in model_inputs:
         model_input_name = get_tensor_name(subgraph, model_input)
-        shape = shape_dict[model_input_name] if model_input_name in shape_dict else None
-        dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32"
+        shape = _shape_dict[model_input_name] if model_input_name in _shape_dict else None
+        dtype = _dtype_dict[model_input_name] if model_input_name in _dtype_dict else "float32"
         exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype))
 
     # op code in model
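With this change, `relay.frontend.from_tflite` falls back to the shapes and dtypes stored in the TFLite flatbuffer whenever `shape_dict` or `dtype_dict` is omitted, and a user-supplied dictionary only overrides the inferred entries for the inputs it names; TVMC now simply forwards its (possibly `None`) `shape_dict`. A minimal usage sketch of the new defaults, assuming the flatbuffers-generated `tflite` package is installed; the model path and the input tensor name below are placeholders, not part of this patch:

```python
import tflite
from tvm import relay

# Read a serialized TFLite model (the path is a placeholder).
with open("model.tflite", "rb") as f:
    buf = f.read()

# Newer `tflite` packages expose GetRootAsModel on tflite.Model;
# older releases need tflite.Model.Model.GetRootAsModel instead.
tflite_model = tflite.Model.GetRootAsModel(buf, 0)

# Input shapes and dtypes are now inferred from the model itself.
mod, params = relay.frontend.from_tflite(tflite_model)

# An explicit shape_dict still overrides the inferred shape for the
# inputs it names ("input" is a placeholder tensor name).
mod, params = relay.frontend.from_tflite(
    tflite_model, shape_dict={"input": (1, 224, 224, 3)}
)
```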