From 6ecd6d9ca1778420f4e9c92050c6ee3b75ef718d Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Thu, 16 Apr 2020 13:24:44 +0100 Subject: [PATCH 1/5] [RELAY] Move frontend utils The util file currently under frontend is used from outside of frontend (in qnn/op/legalizations). This suggests that the file should be pushed up to a higher level. The benefit from this change is that importing qnn no longer also imports all the frontends. --- python/tvm/relay/frontend/tflite.py | 2 +- python/tvm/relay/qnn/op/legalizations.py | 2 +- python/tvm/relay/{frontend => }/util.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename python/tvm/relay/{frontend => }/util.py (98%) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d489bd34f7ac..aba5a6286e50 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -28,7 +28,7 @@ from .. import op as _op from .. import qnn as _qnn from ... import nd as _nd -from .util import get_scalar_from_constant +from ..util import get_scalar_from_constant from .common import ExprTable from .common import infer_shape as _infer_shape diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index b1c19092b4c7..f9874b78467e 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -21,7 +21,7 @@ import tvm from tvm import relay from .. import op as reg -from ...frontend.util import get_scalar_from_constant +from ...util import get_scalar_from_constant ################################################# # Register the functions for different operators. diff --git a/python/tvm/relay/frontend/util.py b/python/tvm/relay/util.py similarity index 98% rename from python/tvm/relay/frontend/util.py rename to python/tvm/relay/util.py index a7f89a30b996..b207182e4113 100644 --- a/python/tvm/relay/frontend/util.py +++ b/python/tvm/relay/util.py @@ -18,7 +18,7 @@ """ Utility functions that are used across many directories. """ from __future__ import absolute_import import numpy as np -from .. import expr as _expr +from . import expr as _expr def get_scalar_from_constant(expr): """ Returns scalar value from Relay constant scalar. """ From 08136db919f8920a6d500623bf8b75fa1fa51b0a Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Fri, 24 Apr 2020 09:55:30 +0100 Subject: [PATCH 2/5] Inline get_scalar_from_constant Change-Id: I1cc64e9ecb0eadb6ac0f7b62e6ea174644af4ad4 --- python/tvm/relay/frontend/tflite.py | 15 ++++++++++++++- python/tvm/relay/qnn/op/legalizations.py | 13 ++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index aba5a6286e50..71fbaa0b09eb 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -28,7 +28,6 @@ from .. import op as _op from .. import qnn as _qnn from ... import nd as _nd -from ..util import get_scalar_from_constant from .common import ExprTable from .common import infer_shape as _infer_shape @@ -2219,6 +2218,20 @@ def get_expr(self, input_tensor_idx): def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + +def get_scalar_from_constant(expr): + """ Returns scalar value from Relay constant scalar. """ + assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ + "Expr is not a constant scalar." + value = expr.data.asnumpy() + if value.dtype == np.dtype(np.int32): + return int(value) + if value.dtype == np.dtype(np.float32): + return float(value) + assert False, "Constant expr must be float32/int32" + return None # To suppress pylint + + def build_str_map(obj): """Build string map of TFLite enum int value diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index f9874b78467e..9669fa1b54b9 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -21,7 +21,6 @@ import tvm from tvm import relay from .. import op as reg -from ...util import get_scalar_from_constant ################################################# # Register the functions for different operators. @@ -54,6 +53,18 @@ def qnn_dense_legalize(attrs, inputs, types): # Helper functions. ################### +def get_scalar_from_constant(expr): + """ Returns scalar value from Relay constant scalar. """ + assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ + "Expr is not a constant scalar." + value = expr.data.asnumpy() + if value.dtype == np.dtype(np.int32): + return int(value) + if value.dtype == np.dtype(np.float32): + return float(value) + assert False, "Constant expr must be float32/int32" + return None # To suppress pylint + # Helper function for lowering in the abscence of fast Int8 arithmetic units. def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do From 7d890276af5f92d648f28262cfb4619c3115b783 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Fri, 24 Apr 2020 09:56:00 +0100 Subject: [PATCH 3/5] Remove util.py from Relay Change-Id: If9cd7cf3fc0bd1861a3a9b5604f338e084d8db96 --- python/tvm/relay/util.py | 33 --------------------------------- 1 file changed, 33 deletions(-) delete mode 100644 python/tvm/relay/util.py diff --git a/python/tvm/relay/util.py b/python/tvm/relay/util.py deleted file mode 100644 index b207182e4113..000000000000 --- a/python/tvm/relay/util.py +++ /dev/null @@ -1,33 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, redefined-builtin, invalid-name -""" Utility functions that are used across many directories. """ -from __future__ import absolute_import -import numpy as np -from . import expr as _expr - -def get_scalar_from_constant(expr): - """ Returns scalar value from Relay constant scalar. """ - assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ - "Expr is not a constant scalar." - value = expr.data.asnumpy() - if value.dtype == np.dtype(np.int32): - return int(value) - if value.dtype == np.dtype(np.float32): - return float(value) - assert False, "Constant expr must be float32/int32" - return None # To suppress pylint From 987cdbcb8a841f7721a85e8baf530b00f19d9afe Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Fri, 24 Apr 2020 16:32:09 +0100 Subject: [PATCH 4/5] Shorten functions Change-Id: Ieb537d82e6ee52421ff05a90cd00a03679ffebf2 --- python/tvm/relay/frontend/tflite.py | 8 ++------ python/tvm/relay/qnn/op/legalizations.py | 11 ++++------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 71fbaa0b09eb..10fbda5313ae 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2224,12 +2224,8 @@ def get_scalar_from_constant(expr): assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ "Expr is not a constant scalar." value = expr.data.asnumpy() - if value.dtype == np.dtype(np.int32): - return int(value) - if value.dtype == np.dtype(np.float32): - return float(value) - assert False, "Constant expr must be float32/int32" - return None # To suppress pylint + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), "value must be float32/int32" + return np.asscalar(value) def build_str_map(obj): diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 9669fa1b54b9..707be7ddb9ac 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -20,6 +20,7 @@ import tvm from tvm import relay +import numpy as np from .. import op as reg ################################################# @@ -55,15 +56,11 @@ def qnn_dense_legalize(attrs, inputs, types): def get_scalar_from_constant(expr): """ Returns scalar value from Relay constant scalar. """ - assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ + assert isinstance(expr, relay.Constant) and not expr.data.shape, \ "Expr is not a constant scalar." value = expr.data.asnumpy() - if value.dtype == np.dtype(np.int32): - return int(value) - if value.dtype == np.dtype(np.float32): - return float(value) - assert False, "Constant expr must be float32/int32" - return None # To suppress pylint + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), "value must be float32/int32" + return np.asscalar(value) # Helper function for lowering in the abscence of fast Int8 arithmetic units. def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): From ce604b07c8e8f5ce23bf1696ec0767c87463e905 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Fri, 24 Apr 2020 16:41:07 +0100 Subject: [PATCH 5/5] Line length Change-Id: I1d216b7e73a060c4f118f5da50ce58b18eba907f --- python/tvm/relay/frontend/tflite.py | 3 ++- python/tvm/relay/qnn/op/legalizations.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 10fbda5313ae..65cf119538d6 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2224,7 +2224,8 @@ def get_scalar_from_constant(expr): assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ "Expr is not a constant scalar." value = expr.data.asnumpy() - assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), "value must be float32/int32" + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ + "value must be float32/int32" return np.asscalar(value) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 707be7ddb9ac..c96a730ee6ed 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -59,7 +59,8 @@ def get_scalar_from_constant(expr): assert isinstance(expr, relay.Constant) and not expr.data.shape, \ "Expr is not a constant scalar." value = expr.data.asnumpy() - assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), "value must be float32/int32" + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ + "value must be float32/int32" return np.asscalar(value) # Helper function for lowering in the abscence of fast Int8 arithmetic units.