diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 234beec244ba..aee4bb3e060e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -3804,18 +3804,16 @@ def _impl_v10(cls, inputs, attr, params): # # This function attempts to present 'x' in a form that meets both of those # requirements. - def try_resolve_to_const_scalar(x, dtype_override=None): + def try_resolve_to_const(x, dtype_override=None): x2 = try_resolve_var_to_const(x, params) - x3 = ensure_scalar_shape(x2) - + num_elem = np.prod(infer_shape(x)) + if num_elem == 1: + x2 = ensure_scalar_shape(x2) x_dtype = infer_type(x).checked_type.dtype if (dtype_override is not None) and (dtype_override != x_dtype): - x4 = _op.cast(x3, dtype_override) - else: - x4 = x3 - - x5 = fold_constant(x4) - return x5 + x2 = _op.cast(x2, dtype_override) + x3 = fold_constant(x2) + return x3 # Unpack the inputs and obtain some type info... a, a_scale, a_zp, b, b_scale, b_zp, y_scale, y_zp = inputs @@ -3855,14 +3853,14 @@ def try_resolve_to_const_scalar(x, dtype_override=None): ) # _qnn.op.dense requires the zero-point values to have dtype int32. - a_scale_scalar = try_resolve_to_const_scalar(a_scale) - a_zp_scalar = try_resolve_to_const_scalar(a_zp, "int32") + a_scale_scalar = try_resolve_to_const(a_scale) + a_zp_scalar = try_resolve_to_const(a_zp, "int32") - b_scale_scalar = try_resolve_to_const_scalar(b_scale) - b_zp_scalar = try_resolve_to_const_scalar(b_zp, "int32") + b_scale_scalar = try_resolve_to_const(b_scale) + b_zp_scalar = try_resolve_to_const(b_zp, "int32") - y_scale_scalar = try_resolve_to_const_scalar(y_scale) - y_zp_scalar = try_resolve_to_const_scalar(y_zp, "int32") + y_scale_scalar = try_resolve_to_const(y_scale) + y_zp_scalar = try_resolve_to_const(y_zp, "int32") # TODO: Confirm that we're using 'num_hidden_units' correctly / as intended with # the '_qnn.op.dense' instance below.