From 776e24fce5db7b70b4c9e509d2bbcf7f58d590cc Mon Sep 17 00:00:00 2001 From: Ramana Radhakrishnan Date: Wed, 30 Jun 2021 00:30:12 +0100 Subject: [PATCH] Fix issue with importing models using Tensorflow Lite 2.4.x schema Tensorflow Lite has changed the opcode for BuiltinOperators to be represented as 32 bit integers instead of 8 bit integers in the schema. This is an attempt to fix this in a way that is clean to handle multiple versions of tensorflow lite in the frontend. --- python/tvm/relay/frontend/tflite.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 7e2173943265..a47fdf0141b5 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -251,7 +251,30 @@ def get_op_code_str(self, op): raise ImportError("The tflite package must be installed") op_code_list_idx = op.OpcodeIndex() - op_code_id = self.model.OperatorCodes(op_code_list_idx).BuiltinCode() + + op_c = self.model.OperatorCodes(op_code_list_idx) + # In TFlite 2.4.x there was a change where the type of the field that contained + # the builtin code changed from int8 to int32 in the flat buffer representation. + # However to retain support for old flat buffers that were created, they retained + # the original 8 bit encoding for the operator but in a new field accessed by the + # DeprecatedBuiltinCode method. + # This means that the API function BuiltinCode() is used on an operator + # which was originally encoded as an 8 bit quantity it would look for the + # code in the new int32 field in the schema and this creates the need + # for the check for the magic number of 127 which is indicated by + # BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES + # Remember however that this value came into existence only after Tensorflow + # lite 2.4.x and hence encase it in a try -except block. + # Phew ! + try: + if op_c.BuiltinCode() < BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES: + opc = op_c.DeprecatedBuiltinCode() + else: + opc = op_c.BuiltinCode() + except AttributeError: + opc = op_c.BuiltinCode() + + op_code_id = opc try: op_code_str = self.builtin_op_code[op_code_id] except KeyError: