From e41a7173a40e497e9157385249c1d8d1c0359c21 Mon Sep 17 00:00:00 2001 From: yuan Date: Mon, 24 Jan 2022 21:34:21 +0100 Subject: [PATCH 1/4] fix after cr --- python/tvm/relay/frontend/onnx.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 234beec244ba..b4ae34a3ae64 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3804,10 +3804,13 @@ def _impl_v10(cls, inputs, attr, params): # # This function attempts to present 'x' in a form that meets both of those # requirements. - def try_resolve_to_const_scalar(x, dtype_override=None): + def try_resolve_to_const(x, dtype_override=None): x2 = try_resolve_var_to_const(x, params) - x3 = ensure_scalar_shape(x2) - + num_elem = np.prod(infer_shape(x)) + if num_elem == 1: + x3 = ensure_scalar_shape(x2) + else: + x3 = x2 x_dtype = infer_type(x).checked_type.dtype if (dtype_override is not None) and (dtype_override != x_dtype): x4 = _op.cast(x3, dtype_override) @@ -3855,14 +3858,14 @@ def try_resolve_to_const_scalar(x, dtype_override=None): ) # _qnn.op.dense requires the zero-point values to have dtype int32. - a_scale_scalar = try_resolve_to_const_scalar(a_scale) - a_zp_scalar = try_resolve_to_const_scalar(a_zp, "int32") + a_scale_scalar = try_resolve_to_const(a_scale) + a_zp_scalar = try_resolve_to_const(a_zp, "int32") - b_scale_scalar = try_resolve_to_const_scalar(b_scale) - b_zp_scalar = try_resolve_to_const_scalar(b_zp, "int32") + b_scale_scalar = try_resolve_to_const(b_scale) + b_zp_scalar = try_resolve_to_const(b_zp, "int32") - y_scale_scalar = try_resolve_to_const_scalar(y_scale) - y_zp_scalar = try_resolve_to_const_scalar(y_zp, "int32") + y_scale_scalar = try_resolve_to_const(y_scale) + y_zp_scalar = try_resolve_to_const(y_zp, "int32") # TODO: Confirm that we're using 'num_hidden_units' correctly / as intended with # the '_qnn.op.dense' instance below. From 246149f8477bee2755440aa418e9b494627b4aa0 Mon Sep 17 00:00:00 2001 From: yuan Date: Tue, 25 Jan 2022 07:31:25 +0100 Subject: [PATCH 2/4] fix after cr 2 --- python/tvm/relay/frontend/onnx.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b4ae34a3ae64..aee4bb3e060e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3808,17 +3808,12 @@ def try_resolve_to_const(x, dtype_override=None): x2 = try_resolve_var_to_const(x, params) num_elem = np.prod(infer_shape(x)) if num_elem == 1: - x3 = ensure_scalar_shape(x2) - else: - x3 = x2 + x2 = ensure_scalar_shape(x2) x_dtype = infer_type(x).checked_type.dtype if (dtype_override is not None) and (dtype_override != x_dtype): - x4 = _op.cast(x3, dtype_override) - else: - x4 = x3 - - x5 = fold_constant(x4) - return x5 + x2 = _op.cast(x2, dtype_override) + x3 = fold_constant(x2) + return x3 # Unpack the inputs and obtain some type info... a, a_scale, a_zp, b, b_scale, b_zp, y_scale, y_zp = inputs From dbd14b500e0f9b8336a1da3a1949ebba770ed5b0 Mon Sep 17 00:00:00 2001 From: yuan Date: Tue, 25 Jan 2022 19:47:31 +0100 Subject: [PATCH 3/4] emptycommit From 03fcb8e4c39fe42ad2be39d829ad389583cd2554 Mon Sep 17 00:00:00 2001 From: yuanfz-1 Date: Wed, 26 Jan 2022 10:53:51 +0100 Subject: [PATCH 4/4] emptycommit 2nd try