From 74c7c26d13929561f0443762742513e1e1c6edde Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 28 Jul 2021 22:25:32 +0000 Subject: [PATCH 01/10] conv2d working, fixing conv2d_depthwise --- python/tvm/relay/qnn/op/legalizations.py | 7 +- src/relay/qnn/op/convolution.cc | 97 +++++++++++++++++++++--- 2 files changed, 93 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 961517f863fb..deab179eee95 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -173,8 +173,11 @@ def _shift(data, zero_point, out_dtype): data_modified = relay.cast(data, "int32") data_modified = relay.add(data_modified, relay.const(shift, "int32")) data_modified = relay.cast(data_modified, out_dtype) - zero_point_val = get_scalar_from_constant(zero_point) - zero_point_modified = relay.const(zero_point_val + shift, "int32") + if isinstance(zero_point, relay.Constant): + zero_point_val = get_scalar_from_constant(zero_point) + zero_point_modified = relay.const(zero_point_val + shift, "int32") + else: + zero_point_modified = zero_point + relay.const(shift, "int32") return (data_modified, zero_point_modified) # Collect the dtypes. diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index a5161358865a..36f30c46fa81 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -65,7 +65,6 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, } } ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point - ICHECK(IsScalarType(types[3], DataType::Int(32))); // weight_zero_point ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale // Kernel scale can be a vector of length output_channels or a scalar. if (param->groups == 1) { @@ -293,7 +292,11 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_ auto multiplied_t2 = reduced_t2; auto one_scalar = MakeConstantScalar(DataType::Int(32), 1); if (!IsEqualScalar(kernel_zero_point, one_scalar)) { - multiplied_t2 = Multiply(kernel_zero_point, reduced_t2); + if (!IsConstScalar(kernel_zero_point)) { + multiplied_t2 = Multiply(MakeRepeat(kernel_zero_point, channel_multiplier, 0), reduced_t2); + } else { + multiplied_t2 = Multiply(kernel_zero_point, reduced_t2); + } } // Reduce the C dimension. Find the dimension. @@ -378,6 +381,24 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i return MakeConstantScalar(DataType::Int(32), scalar_term4); } +/* + * \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence + for non-scalar kernel zero_point. + * \param input_zero_point_int The int value of input zero point. + * \param kernel_zero_point The expr for the kernel zero point. + * \param kernel_h The height of kernel. + * \param kernel_w The width of kernel. + * \return The sequence of Relay operators for term4. + * \note The term4 looks like this + * + * Sigma(r, s) zp_a * zp_w + */ +Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, const Expr& kernel_zero_point, int kernel_h, + int kernel_w) { + Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), input_zero_point_int * kernel_h * kernel_w); + return Multiply(scalar_term4, kernel_zero_point); +} + /* * \brief Calculates the first term in the qnn.conv2d lowering sequence. * \param data The input expr. 
@@ -456,7 +477,13 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, auto multiplied_t2 = reduced_t2; auto one_scalar = MakeConstantScalar(DataType::Int(32), 1); + std::cout << "Reduced T2: " << PrettyPrint(transform::InferType()(IRModule::FromExpr(reduced_t2))) << std::endl; if (!IsEqualScalar(kernel_zero_point, one_scalar)) { + if (!IsConstScalar(kernel_zero_point)) { + Layout layout(param->data_layout); + int channel_axis = layout.IndexOf(LayoutAxis::Get('C')); + reduced_t2 = MakeRepeat(reduced_t2, out_channels, channel_axis); + } multiplied_t2 = Multiply(kernel_zero_point, reduced_t2); } return multiplied_t2; @@ -531,6 +558,26 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i return MakeConstantScalar(DataType::Int(32), scalar_term4); } +/* + * \brief Calculates the fourth term in the qnn.conv2d lowering sequence + for non-scalar kernel zero_point. + * \param input_zero_point_int The int value of input zero point. + * \param kernel_zero_point The Expr for the kernel zero point. + * \param in_channels The number of input channels. + * \param kernel_h The height of kernel. + * \param kernel_w The width of kernel. + * \return The sequence of Relay operators for term4. + * \note The term4 looks like this + * + * Sigma(c,r,s) zp_a * zp_w + * + */ +Expr Conv2DFourthTerm(int input_zero_point_int, const Expr& kernel_zero_point, int in_channels, + int kernel_h, int kernel_w) { + Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), input_zero_point_int * in_channels * kernel_h * kernel_w); + return Multiply(scalar_term4, kernel_zero_point); +} + /* * \brief Combines different terms of qnn conv2d lowering. * \param term1 The term1 of qnn conv2d lowering. @@ -658,7 +705,21 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, // Extract the integer zero points. auto input_zero_point_int = GetScalarFromConstant(input_zero_point); - auto kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); + + // Kernel zero point is allowed to be non-scalar. Let's check if that's the case. + bool dynamic_zp = false; + // Use -1 zero point as a default for dynamic. + int kernel_zero_point_int = -1; + if (IsConstScalar(kernel_zero_point)) { + kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); + } else { + // Figure out the channel axis. + Layout layout(param->data_layout); + int channel_axis = layout.IndexOf(LayoutAxis::Get('C')); + kernel_zero_point = Reshape(kernel_zero_point, {-1,}); + kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {channel_axis}); + dynamic_zp = true; + } // Fallback to int32 conv if there is dilation with non-zero kernel point or grouped conv2d // For dilated conv, if the kernel zero point is non-zero, the pooling operator also has to @@ -668,19 +729,32 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, ICHECK_EQ(param->dilation.size(), 2) << "qnn.conv2d only supports 2D dilation"; auto dilation_h = get_const_int(param->dilation[0]); auto dilation_w = get_const_int(param->dilation[1]); - if ((kernel_zero_point_int != 0 && (dilation_h != 1 || dilation_w != 1)) || - (param->groups != 1 && !is_depthwise(param))) { + // Check if qnn supports the conv2d parameters. If not, fallback to regular conv2d. 
+ bool supported_dilation = (kernel_zero_point_int == 0) || (dilation_h == 1 && dilation_w == 1); + bool supported_groups = (param->groups == 1 || is_depthwise(param)); + bool conv2d_params_supported = supported_dilation && supported_groups; + if (!conv2d_params_supported) { + std::cout << "FALLBACK" << std::endl; return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param); } else if (is_depthwise(param)) { ICHECK_NE(channel_multiplier, -1); auto padded_data = Conv2DPadInput(data, input_zero_point, param); auto term1 = Conv2DFirstTerm(padded_data, weight, param); + std::cout << "term1" << PrettyPrint(transform::InferType()(IRModule::FromExpr(term1))) << std::endl; auto term2 = DepthwiseConv2DSecondTerm(padded_data, kernel_zero_point, param, kernel_h, kernel_w, channel_multiplier); + std::cout << "term2" << PrettyPrint(transform::InferType()(IRModule::FromExpr(term2))) << std::endl; auto term3 = DepthwiseConv2DThirdTerm(weight, input_zero_point, param, out_channels, channel_multiplier); - auto term4 = - DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h, kernel_w); + std::cout << "term3" << PrettyPrint(transform::InferType()(IRModule::FromExpr(term3))) << std::endl; + Expr term4; + if (dynamic_zp) { + term4 = DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point, kernel_h, kernel_w); + } else { + term4 = DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h, + kernel_w); + } + std::cout << "term4" << PrettyPrint(transform::InferType()(IRModule::FromExpr(term4))) << std::endl; return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int, kernel_zero_point_int); } @@ -690,8 +764,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, auto term2 = Conv2DSecondTerm(padded_data, kernel_zero_point, param, kernel_h, kernel_w, out_channels); auto term3 = Conv2DThirdTerm(weight, input_zero_point, param, out_channels); - auto term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h, - kernel_w); + Expr term4; + if (dynamic_zp) { + term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point, in_channels, kernel_h, kernel_w); + } else { + term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h, + kernel_w); + } return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int, kernel_zero_point_int); } From 1b9812a643cb37630ae40d85c4dfdee1ba7393bd Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 28 Jul 2021 22:29:47 +0000 Subject: [PATCH 02/10] Depthwise conv2d working. 
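
Patch 01 teaches the qnn.conv2d lowering to accept zero points that are Relay expressions rather than folded integer constants, and patch 02 strips the std::cout tracing that was left in while bringing the depthwise path up. The algebra behind the four terms is unchanged: for data Q_a with zero point zp_a and kernel Q_w with zero point zp_w,

  Sigma(c,r,s) (Q_a - zp_a) * (Q_w - zp_w)
    = Sigma Q_a*Q_w - zp_w * Sigma Q_a - zp_a * Sigma Q_w + zp_a * zp_w * C*R*S,

i.e. term1 - term2 - term3 + term4. A minimal numpy check of that identity (the naive conv helper, shapes, and values below are illustrative, not code from this series):

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.integers(0, 255, size=(1, 4, 5, 5)).astype(np.int32)  # NCHW data
    W = rng.integers(0, 255, size=(3, 4, 2, 2)).astype(np.int32)  # OIHW kernel
    zp_a, zp_w = 5, 3  # scalar zero points

    def conv(a, w):
        # Naive stride-1, unpadded conv2d; just enough to check the algebra.
        n, c, h, width = a.shape
        o, _, kh, kw = w.shape
        out = np.zeros((n, o, h - kh + 1, width - kw + 1), dtype=np.int64)
        for y in range(out.shape[2]):
            for x in range(out.shape[3]):
                patch = a[:, :, y : y + kh, x : x + kw]
                out[:, :, y, x] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3]))
        return out

    reference = conv(A - zp_a, W - zp_w)
    term1 = conv(A, W)
    term2 = zp_w * conv(A, np.ones_like(W))                    # zp_w * Sigma Q_a
    term3 = zp_a * W.sum(axis=(1, 2, 3))[None, :, None, None]  # zp_a * Sigma Q_w
    term4 = zp_a * zp_w * W[0].size                            # zp_a * zp_w * C*R*S
    assert (reference == term1 - term2 - term3 + term4).all()

When a zero point is not a Constant scalar, term4 can no longer be folded to an integer, so the new overloads build it as Multiply of a constant count and the zero-point expression, and the per-channel kernel zero point is broadcast (MakeRepeat / ExpandBiasToMatchAxis) before it multiplies the reduced data in term2.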
--- src/relay/qnn/op/convolution.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 36f30c46fa81..aeed7999cd39 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -477,7 +477,6 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, auto multiplied_t2 = reduced_t2; auto one_scalar = MakeConstantScalar(DataType::Int(32), 1); - std::cout << "Reduced T2: " << PrettyPrint(transform::InferType()(IRModule::FromExpr(reduced_t2))) << std::endl; if (!IsEqualScalar(kernel_zero_point, one_scalar)) { if (!IsConstScalar(kernel_zero_point)) { Layout layout(param->data_layout); @@ -734,19 +733,15 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, bool supported_groups = (param->groups == 1 || is_depthwise(param)); bool conv2d_params_supported = supported_dilation && supported_groups; if (!conv2d_params_supported) { - std::cout << "FALLBACK" << std::endl; return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param); } else if (is_depthwise(param)) { ICHECK_NE(channel_multiplier, -1); auto padded_data = Conv2DPadInput(data, input_zero_point, param); auto term1 = Conv2DFirstTerm(padded_data, weight, param); - std::cout << "term1" << PrettyPrint(transform::InferType()(IRModule::FromExpr(term1))) << std::endl; auto term2 = DepthwiseConv2DSecondTerm(padded_data, kernel_zero_point, param, kernel_h, kernel_w, channel_multiplier); - std::cout << "term2" << PrettyPrint(transform::InferType()(IRModule::FromExpr(term2))) << std::endl; auto term3 = DepthwiseConv2DThirdTerm(weight, input_zero_point, param, out_channels, channel_multiplier); - std::cout << "term3" << PrettyPrint(transform::InferType()(IRModule::FromExpr(term3))) << std::endl; Expr term4; if (dynamic_zp) { term4 = DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point, kernel_h, kernel_w); @@ -754,7 +749,6 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, term4 = DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h, kernel_w); } - std::cout << "term4" << PrettyPrint(transform::InferType()(IRModule::FromExpr(term4))) << std::endl; return Conv2DCombineTerms(term1, term2, term3, term4, input_zero_point_int, kernel_zero_point_int); } From 1ffa74c4d541ef4b11769984eef7bfdaa0e1c7f9 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 30 Jul 2021 23:35:03 +0000 Subject: [PATCH 03/10] Make convinteger work on cuda. 
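
This patch switches the CUDA legalization from "make both dtypes the same" to "make both dtypes int8", the layout the dp4a instruction consumes, and hoists _shift to module scope so the new helper_change_dtypes_to_int8 can reuse it. The rewrite is value-preserving because a quantized value only has meaning relative to its zero point: real = scale * (q - zp), so shifting q and zp by the same amount changes nothing. Note that a uint8 -> int8 retarget shifts by -128 (the "(QA + 128)" inline comment is carried over from the uint8 path); a short numeric check of the invariant, with scale and zero-point values assumed for illustration:

    import numpy as np

    scale, zp = 0.05, 140
    q_u8 = np.random.randint(0, 256, size=(8,))
    real_before = scale * (q_u8 - zp)

    q_i8 = (q_u8 - 128).astype(np.int8)  # cast(add(data, -128), "int8")
    zp_i8 = zp - 128                     # zero point shifted by the same amount
    real_after = scale * (q_i8.astype(np.int32) - zp_i8)

    assert np.allclose(real_before, real_after)

The non-Constant branch of _shift builds zero_point + relay.const(shift, "int32") as an expression rather than folding, matching the dynamic zero points patch 01 made legal. The ONNX ConvInteger test is also reworked so its zero points become graph initializers, letting the importer treat them as constants.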
--- python/tvm/relay/qnn/op/legalizations.py | 94 +++++++++--- src/relay/qnn/op/convolution.cc | 47 +++--- tests/python/frontend/onnx/test_forward.py | 161 ++++----------------- 3 files changed, 130 insertions(+), 172 deletions(-) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index deab179eee95..76da216e5534 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -94,6 +94,25 @@ def get_scalar_from_constant(expr): return value.item(0) +def _shift(data, zero_point, out_dtype): + """Shifts (add/subtracts) the qnn tensor with +/-128)""" + if out_dtype == "uint8": + shift = 128 + elif out_dtype == "int8": + shift = -128 + else: + raise ValueError("Unsupported out dtype.") + data_modified = relay.cast(data, "int32") + data_modified = relay.add(data_modified, relay.const(shift, "int32")) + data_modified = relay.cast(data_modified, out_dtype) + if isinstance(zero_point, relay.Constant): + zero_point_val = get_scalar_from_constant(zero_point) + zero_point_modified = relay.const(zero_point_val + shift, "int32") + else: + zero_point_modified = zero_point + relay.const(shift, "int32") + return (data_modified, zero_point_modified) + + # Helper function for lowering in the abscence of fast Int8 arithmetic units. def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): """Converts QNN operators into a sequence of Relay operators that are friendly to HW that do @@ -161,25 +180,6 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op): result : tvm.relay.Expr The legalized expr """ - - def _shift(data, zero_point, out_dtype): - """Shifts (add/subtracts) the qnn tensor with +/-128)""" - if out_dtype == "uint8": - shift = 128 - elif out_dtype == "int8": - shift = -128 - else: - raise ValueError("Unsupported out dtype.") - data_modified = relay.cast(data, "int32") - data_modified = relay.add(data_modified, relay.const(shift, "int32")) - data_modified = relay.cast(data_modified, out_dtype) - if isinstance(zero_point, relay.Constant): - zero_point_val = get_scalar_from_constant(zero_point) - zero_point_modified = relay.const(zero_point_val + shift, "int32") - else: - zero_point_modified = zero_point + relay.const(shift, "int32") - return (data_modified, zero_point_modified) - # Collect the dtypes. data_dtype = types[0].dtype kernel_dtype = types[1].dtype @@ -208,6 +208,54 @@ def _shift(data, zero_point, out_dtype): ) +# Helper function to change dtypes to int8 x int8. Cuda dp4a instructions prefer this setting. +def helper_change_dtypes_to_int8(attrs, inputs, types, relay_op): + """Legalizes QNN conv2d/dense op for Nvidia HW. dp4a supports i8 x i8 fast conv/MM. If the dtypes + are already good, we dont transform. Else, we shift the tensor values and zero points to change + the dtype. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + # Collect the dtypes. + data_dtype = types[0].dtype + kernel_dtype = types[1].dtype + + # Collect the input exprs. + data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs + + # dp4a supports i8 x i8 fast conv/MM. Don't do anything if it is already satisfied. + if data_dtype == "int8" and kernel_dtype == "int8": + return None + + # Shift input if necessary. 
+ if data_dtype == "uint8": + # Compute (QA + 128) and (zp_a + 128) + data, input_zero_point = _shift(data, input_zero_point, "int8") + + # Shift kernel if necessary. + if kernel_dtype == "uint8": + # Compute (QA - 128) and (zp_a - 128) + kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, "int8") + + # Call qnn.conv2d with modified inputs and zero points. + new_attrs = {k: attrs[k] for k in attrs.keys()} + return relay_op( + data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale, **new_attrs + ) + + # Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting. def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op): """Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However, @@ -342,11 +390,11 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types): @qnn_conv2d_legalize.register("cuda") def _qnn_conv2d_legalize_cuda(attrs, inputs, types): - # CUDA prefers the dtypes to be same. - return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d) + # CUDA prefers both datatypes to be int8. + return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.conv2d) @qnn_dense_legalize.register("cuda") def _qnn_dense_legalize_cuda(attrs, inputs, types): - # CUDA prefers the dtypes to be same. - return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense) + # CUDA prefers both datatypes to be the int8. + return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.dense) diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index aeed7999cd39..e3c3972b7313 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -383,9 +383,9 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i /* * \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence - for non-scalar kernel zero_point. - * \param input_zero_point_int The int value of input zero point. - * \param kernel_zero_point The expr for the kernel zero point. + for non-constant zero_points. + * \param input_zero_point The Expr for the input zero point. + * \param kernel_zero_point The Expr for the kernel zero point. * \param kernel_h The height of kernel. * \param kernel_w The width of kernel. * \return The sequence of Relay operators for term4. @@ -393,10 +393,11 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i * * Sigma(r, s) zp_a * zp_w */ -Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, const Expr& kernel_zero_point, int kernel_h, +Expr DepthwiseConv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int kernel_h, int kernel_w) { - Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), input_zero_point_int * kernel_h * kernel_w); - return Multiply(scalar_term4, kernel_zero_point); + Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w); + Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point); + return Multiply(scalar_term4, variable_term4); } /* @@ -559,8 +560,8 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i /* * \brief Calculates the fourth term in the qnn.conv2d lowering sequence - for non-scalar kernel zero_point. - * \param input_zero_point_int The int value of input zero point. + for non-constant zero_points. + * \param input_zero_point The Expr for the input zero point. * \param kernel_zero_point The Expr for the kernel zero point. 
* \param in_channels The number of input channels. * \param kernel_h The height of kernel. @@ -571,10 +572,11 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i * Sigma(c,r,s) zp_a * zp_w * */ -Expr Conv2DFourthTerm(int input_zero_point_int, const Expr& kernel_zero_point, int in_channels, +Expr Conv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int in_channels, int kernel_h, int kernel_w) { - Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), input_zero_point_int * in_channels * kernel_h * kernel_w); - return Multiply(scalar_term4, kernel_zero_point); + Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), in_channels * kernel_h * kernel_w); + Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point); + return Multiply(scalar_term4, variable_term4); } /* @@ -702,17 +704,26 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) = GetWorkload(arg_types, param); - // Extract the integer zero points. - auto input_zero_point_int = GetScalarFromConstant(input_zero_point); - - // Kernel zero point is allowed to be non-scalar. Let's check if that's the case. + // zero points are allowed to be non-scalar. Let's check if that's the case. bool dynamic_zp = false; // Use -1 zero point as a default for dynamic. + int input_zero_point_int = -1; int kernel_zero_point_int = -1; + + // Input zero point can either be a constant or a scalar expression. + if (IsConstScalar(input_zero_point)) { + // Extract the integer zero points. + input_zero_point_int = GetScalarFromConstant(input_zero_point); + } else { + dynamic_zp = true; + } + + // Kernel zero point is allowed to be a constant or 1-D tensor. if (IsConstScalar(kernel_zero_point)) { + // Extract the integer zero points. kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); } else { - // Figure out the channel axis. + // Figure out the channel axis to force appropriate shape. 
Layout layout(param->data_layout); int channel_axis = layout.IndexOf(LayoutAxis::Get('C')); kernel_zero_point = Reshape(kernel_zero_point, {-1,}); @@ -744,7 +755,7 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, DepthwiseConv2DThirdTerm(weight, input_zero_point, param, out_channels, channel_multiplier); Expr term4; if (dynamic_zp) { - term4 = DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point, kernel_h, kernel_w); + term4 = DepthwiseConv2DFourthTerm(input_zero_point, kernel_zero_point, kernel_h, kernel_w); } else { term4 = DepthwiseConv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, kernel_h, kernel_w); @@ -760,7 +771,7 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, auto term3 = Conv2DThirdTerm(weight, input_zero_point, param, out_channels); Expr term4; if (dynamic_zp) { - term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point, in_channels, kernel_h, kernel_w); + term4 = Conv2DFourthTerm(input_zero_point, kernel_zero_point, in_channels, kernel_h, kernel_w); } else { term4 = Conv2DFourthTerm(input_zero_point_int, kernel_zero_point_int, in_channels, kernel_h, kernel_w); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 8b633c18977a..8d9522c38fa9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -615,10 +615,7 @@ def test_dynamic_gather(): inputs=[], outputs=["indices"], value=onnx.helper.make_tensor( - name="const_indices", - data_type=onnx.TensorProto.INT64, - dims=[], - vals=[1], + name="const_indices", data_type=onnx.TensorProto.INT64, dims=[], vals=[1], ), ) y = helper.make_node("Gather", ["in", "indices"], ["out"], axis=axis) @@ -891,12 +888,7 @@ def test_slice(): x, x[:, 1::2], starts=(1,), ends=(9223372036854775807,), axes=(1,), steps=(2,) ) _test_slice_iteration_v10( - x, - x[0::1, 1::2], - starts=(0, 1), - ends=(4, 4), - axes=(0, 1), - steps=(1, 2), + x, x[0::1, 1::2], starts=(0, 1), ends=(4, 4), axes=(0, 1), steps=(1, 2), ) @@ -1279,10 +1271,7 @@ def verify_instance_norm(shape, axis=1): epsilon = 1e-5 node = onnx.helper.make_node( - "InstanceNormalization", - inputs=["x", "gamma", "beta"], - outputs=["y"], - epsilon=epsilon, + "InstanceNormalization", inputs=["x", "gamma", "beta"], outputs=["y"], epsilon=epsilon, ) graph = helper.make_graph( [node], @@ -1645,13 +1634,7 @@ def verify_pad(indata, pads, mode="constant", value=0.0): # onnx graph if mode in ["edge", "reflect"]: outdata = np.pad(indata, pad_width=np_pads, mode=mode) - node = helper.make_node( - "Pad", - inputs=["input"], - outputs=["output"], - mode=mode, - pads=pads, - ) + node = helper.make_node("Pad", inputs=["input"], outputs=["output"], mode=mode, pads=pads,) else: outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) node = helper.make_node( @@ -1927,9 +1910,7 @@ def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5, dtype="float32"): graph = helper.make_graph( [z], "_test", - inputs=[ - helper.make_tensor_value_info("in1", ONNX_DTYPE, list(in_shape)), - ], + inputs=[helper.make_tensor_value_info("in1", ONNX_DTYPE, list(in_shape)),], outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, list(out_shape))], ) model = helper.make_model(graph, producer_name="_test") @@ -2112,11 +2093,7 @@ def Sign_x(x): def verify_not(indata, dtype): x = indata.astype(dtype) - node = helper.make_node( - "Not", - inputs=["in"], - outputs=["out"], - ) + node = helper.make_node("Not", inputs=["in"], 
outputs=["out"],) graph = helper.make_graph( [node], @@ -2144,11 +2121,7 @@ def verify_and(indata, dtype): y = indata[1].astype(dtype) outdata = np.logical_and(x, y) - node = helper.make_node( - "And", - inputs=["in1", "in2"], - outputs=["out"], - ) + node = helper.make_node("And", inputs=["in1", "in2"], outputs=["out"],) graph = helper.make_graph( [node], @@ -2300,11 +2273,7 @@ def verify_or(indata, dtype): y = indata[1].astype(dtype) outdata = np.logical_or(x, y) - node = helper.make_node( - "Or", - inputs=["in1", "in2"], - outputs=["out"], - ) + node = helper.make_node("Or", inputs=["in1", "in2"], outputs=["out"],) graph = helper.make_graph( [node], @@ -3649,14 +3618,7 @@ def verify_opset_10(ishape, scales, mode): make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales), ] input_names = ["X", "scales"] - nodes.append( - helper.make_node( - "Resize", - inputs=input_names, - outputs=["Y"], - mode=mode, - ) - ) + nodes.append(helper.make_node("Resize", inputs=input_names, outputs=["Y"], mode=mode,)) oshape = [round(dim * scale) for (dim, scale) in zip(ishape, scales)] graph = helper.make_graph( @@ -3676,11 +3638,7 @@ def verify_opset_10(ishape, scales, mode): @tvm.testing.uses_gpu def test_nonzero(): def verify_nonzero(indata, outdata, dtype): - node = helper.make_node( - "NonZero", - inputs=["X"], - outputs=["Y"], - ) + node = helper.make_node("NonZero", inputs=["X"], outputs=["Y"],) graph = helper.make_graph( [node], @@ -3717,13 +3675,7 @@ def verify_topk(input_dims, K, axis=-1): "topk_test", inputs=[ helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), - helper.make_tensor_value_info( - "K", - TensorProto.INT64, - [ - 1, - ], - ), + helper.make_tensor_value_info("K", TensorProto.INT64, [1,],), ], outputs=[ helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), @@ -3776,13 +3728,7 @@ def verify_roi_align( inputs=[ helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), helper.make_tensor_value_info("rois", TensorProto.FLOAT, [num_roi, 4]), - helper.make_tensor_value_info( - "batch_indicies", - TensorProto.INT64, - [ - num_roi, - ], - ), + helper.make_tensor_value_info("batch_indicies", TensorProto.INT64, [num_roi,],), ], outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_dims)], ) @@ -3835,10 +3781,7 @@ def verify_nms( ) inputs.append(score_threshold) node = helper.make_node( - "NonMaxSuppression", - inputs=input_names, - outputs=["Y"], - center_point_box=0, + "NonMaxSuppression", inputs=input_names, outputs=["Y"], center_point_box=0, ) graph = helper.make_graph( @@ -4152,9 +4095,7 @@ def append_constant_nodes(nodes, outputs, expected, name): if_graph = onnx.helper.make_graph( [if_node], "if_outer", - inputs=[ - onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), - ], + inputs=[onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []),], outputs=graph_outputs, ) @@ -4187,11 +4128,7 @@ def test_if(): @tvm.testing.uses_gpu def test_size(): def verify_size(indata): - node = helper.make_node( - "Size", - inputs=["X"], - outputs=["Y"], - ) + node = helper.make_node("Size", inputs=["X"], outputs=["Y"],) graph = helper.make_graph( [node], @@ -4282,11 +4219,7 @@ def verify_maxunpool(data, indices, kernel_shape, strides, output_shape=None, pa @tvm.testing.uses_gpu def test_softplus(): def verify_softplus(indata): - node = helper.make_node( - "Softplus", - inputs=["X"], - outputs=["Y"], - ) + node = helper.make_node("Softplus", inputs=["X"], outputs=["Y"],) graph = 
helper.make_graph( [node], @@ -4310,11 +4243,7 @@ def verify_softplus(indata): @tvm.testing.uses_gpu def test_cumsum(): def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): - cumsum_node = onnx.helper.make_node( - "CumSum", - inputs=["X", "axis"], - outputs=["Y"], - ) + cumsum_node = onnx.helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"],) if exclusive != 0: exclusive_attr = helper.make_attribute("exclusive", exclusive) cumsum_node.attribute.append(exclusive_attr) @@ -4334,9 +4263,7 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): graph = helper.make_graph( nodes, "cumsum_test", - inputs=[ - helper.make_tensor_value_info("X", tensor_type, list(indata.shape)), - ], + inputs=[helper.make_tensor_value_info("X", tensor_type, list(indata.shape)),], outputs=[helper.make_tensor_value_info("Y", tensor_type, list(indata.shape))], ) @@ -4345,22 +4272,7 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): verify_with_ort_with_inputs(model, [indata], dtype=type, use_vm=True, opset=11) data = ( - np.array( - [ - 1.0, - 2.0, - 3.0, - 4.0, - 5.0, - 6.0, - 7.0, - 8.0, - 9.0, - 10.0, - 11.0, - 12.0, - ] - ) + np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,]) .astype(np.float32) .reshape((3, 4)) ) @@ -4387,11 +4299,7 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): @tvm.testing.uses_gpu def test_eyelike(): def verify_eyelike(indata): - node = helper.make_node( - "EyeLike", - inputs=["X"], - outputs=["Y"], - ) + node = helper.make_node("EyeLike", inputs=["X"], outputs=["Y"],) graph = helper.make_graph( [node], @@ -4751,11 +4659,7 @@ def test_onnx_nodes(test, target): def test_wrong_input(): - node = helper.make_node( - "Softplus", - inputs=["X"], - outputs=["Y"], - ) + node = helper.make_node("Softplus", inputs=["X"], outputs=["Y"],) graph = helper.make_graph( [node], @@ -4836,10 +4740,7 @@ def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis): @tvm.testing.uses_gpu def test_reverse_sequence(): - x = np.array( - [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]], - dtype=np.float32, - ) + x = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]], dtype=np.float32,) sequence_lens = np.array([1, 2, 3, 4], dtype=np.int64) verify_reverse_sequence(x, sequence_lens, 0, 1) @@ -5162,23 +5063,20 @@ def verify_convinteger( x_array = np.random.randint(low=0, high=255, size=x_shape).astype(dtype) w_array = np.random.uniform(low=0, high=255, size=w_shape).astype(dtype) - x_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype) - w_zero_point_array = np.random.randint(0, 255, size=[]).astype(dtype) + x_zero_point_array = np.random.randint(0, 255, size=[1]).astype(dtype) + w_zero_point_array = np.random.randint(0, 255, size=[1]).astype(dtype) ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] input_nodes = [ helper.make_tensor_value_info("x", ONNX_DTYPE, list(x_shape)), helper.make_tensor_value_info("w", ONNX_DTYPE, list(w_shape)), - helper.make_tensor_value_info("x_zero_point", ONNX_DTYPE, []), - helper.make_tensor_value_info("w_zero_point", ONNX_DTYPE, []), ] - input_names = [ - "x", - "w", - "x_zero_point", - "w_zero_point", + initializer = [ + helper.make_tensor("x_zero_point", ONNX_DTYPE, [], x_zero_point_array), + helper.make_tensor("w_zero_point", ONNX_DTYPE, [], w_zero_point_array), ] - input_values = [x_array, w_array, x_zero_point_array, w_zero_point_array] + input_names = ["x", "w", "x_zero_point", 
"w_zero_point"] + input_values = [x_array, w_array] if padding is None: ## autopadding with unset default attributes @@ -5213,6 +5111,7 @@ def verify_convinteger( [node], "convinteger_test", inputs=input_nodes, + initializer=initializer, outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, list(y_shape))], ) model = helper.make_model(graph, producer_name="convinteger_test") From 9ea730e807cd6684ab18603805953b0222cc4f85 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 2 Aug 2021 15:27:56 +0000 Subject: [PATCH 04/10] Simplify code and add tests. --- src/relay/qnn/op/convolution.cc | 11 +--- tests/python/relay/test_op_qnn_conv2d.py | 76 ++++++++++++++++++++++-- 2 files changed, 74 insertions(+), 13 deletions(-) diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index e3c3972b7313..aa42283da2a4 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -711,19 +711,12 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, int kernel_zero_point_int = -1; // Input zero point can either be a constant or a scalar expression. - if (IsConstScalar(input_zero_point)) { + if (IsConstScalar(input_zero_point) && (IsConstScalar(kernel_zero_point))) { // Extract the integer zero points. input_zero_point_int = GetScalarFromConstant(input_zero_point); - } else { - dynamic_zp = true; - } - - // Kernel zero point is allowed to be a constant or 1-D tensor. - if (IsConstScalar(kernel_zero_point)) { - // Extract the integer zero points. kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); } else { - // Figure out the channel axis to force appropriate shape. + // Figure out the channel axis to force appropriate shape for kernel. Layout layout(param->data_layout); int channel_axis = layout.IndexOf(LayoutAxis::Get('C')); kernel_zero_point = Reshape(kernel_zero_point, {-1,}); diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 3a81e6e7b47a..9cb09d84b485 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -49,10 +49,19 @@ def get_ref_func( groups, channels=None, ): + if isinstance(input_zero_point, (int, float)): + input_zero_point = relay.const(input_zero_point, "int32") + if isinstance(kernel_zero_point, (int, float)): + kernel_zero_point = relay.const(kernel_zero_point, "int32") + else: + # Kernel zero point expression requires manual broadcasting for OIHW. 
+ if kernel_layout == "OIHW": + kernel_zero_point = relay.reshape(kernel_zero_point, [-1, 1, 1, 1]) + casted_data = relay.op.cast(data, "int32") casted_kernel = relay.op.cast(kernel, "int32") - shifted_data = relay.op.subtract(casted_data, relay.const(input_zero_point, "int32")) - shifted_kernel = relay.op.subtract(casted_kernel, relay.const(kernel_zero_point, "int32")) + shifted_data = relay.op.subtract(casted_data, input_zero_point) + shifted_kernel = relay.op.subtract(casted_kernel, kernel_zero_point) func = relay.op.nn.conv2d( shifted_data, shifted_kernel, @@ -88,11 +97,16 @@ def get_qnn_func( channels, groups, ): + if isinstance(input_zero_point, (int, float)): + input_zero_point = relay.const(input_zero_point, "int32") + if isinstance(kernel_zero_point, (int, float)): + kernel_zero_point = relay.const(kernel_zero_point, "int32") + func = relay.qnn.op.conv2d( data, kernel, - input_zero_point=relay.const(input_zero_point, "int32"), - kernel_zero_point=relay.const(kernel_zero_point, "int32"), + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, input_scale=relay.const(input_scale, "float32"), kernel_scale=relay.const(kernel_scale, "float32"), kernel_size=kernel_size, @@ -419,6 +433,60 @@ def test_both_zero_point(): verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) +def test_dynamic_zero_point(): + with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d): + + # uint8 input with non static zero points. + data_shape = (2, 4, 2, 4) + data_dtype = "uint8" + kernel_shape = (3, 4, 2, 2) + kernel_dtype = "uint8" + input_zero_point = relay.op.multiply(relay.const(2, dtype='int32'), relay.const(2, dtype='int32')) + kernel_zero_point = relay.const(np.random.randint(10, size=[3]), 'int32') + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + # int8 input + data_shape = (2, 4, 2, 4) + data_dtype = "int8" + kernel_shape = (3, 4, 2, 2) + kernel_dtype = "int8" + ref_func, qnn_func = get_funcs( + data_shape=data_shape, + data_dtype=data_dtype, + kernel_shape=kernel_shape, + kernel_dtype=kernel_dtype, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, + input_scale=1.0, + kernel_scale=1.0, + kernel_size=(2, 2), + padding=(0, 0), + strides=(1, 1), + dilation=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="int32", + ) + verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) + + def test_layout(): with TempOpAttr("qnn.conv2d", "FTVMQnnLegalize", legalize_qnn_conv2d): From cee02bca7f6e345771651387f79826e371017db4 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 2 Aug 2021 15:29:39 +0000 Subject: [PATCH 05/10] Formatting. 
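
This patch is mechanical formatting: it re-expands the helper.make_node / make_graph calls that patch 03 had collapsed and rewraps the long zero-point expressions in the new conv2d test. For reference, the pattern that test_dynamic_zero_point (added in patch 04) exercises condenses to the sketch below; it mirrors the test's shapes and values, with the surrounding verification boilerplate omitted:

    import numpy as np
    from tvm import relay

    data = relay.var("data", shape=(2, 4, 2, 4), dtype="uint8")
    kernel = relay.var("kernel", shape=(3, 4, 2, 2), dtype="uint8")
    # Zero points as expressions: a computed scalar and a per-channel vector.
    input_zp = relay.op.multiply(relay.const(2, "int32"), relay.const(2, "int32"))
    kernel_zp = relay.const(np.random.randint(10, size=[3]), "int32")

    out = relay.qnn.op.conv2d(
        data,
        kernel,
        input_zero_point=input_zp,
        kernel_zero_point=kernel_zp,
        input_scale=relay.const(1.0, "float32"),
        kernel_scale=relay.const(1.0, "float32"),
        kernel_size=(2, 2),
        channels=3,
        padding=(0, 0),
        out_dtype="int32",
    )

Because neither zero point is a Constant scalar, QnnConv2DCanonicalize takes the dynamic_zp path and lowers through the expression-valued term2/term4 added earlier in the series.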
--- python/tvm/contrib/target/onnx.py | 2 +- src/relay/qnn/op/convolution.cc | 12 +- tests/python/frontend/onnx/test_forward.py | 143 +++++++++++++++++---- tests/python/relay/test_op_qnn_conv2d.py | 6 +- 4 files changed, 133 insertions(+), 30 deletions(-) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index b839af669fe6..6f8aab23cde1 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -655,7 +655,7 @@ def convert_attributes(cls, attrs): class Cast(OpConverter): - """ Operator converter for Cast.""" + """Operator converter for Cast.""" @classmethod def convert_attributes(cls, attrs): diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index aa42283da2a4..0b21eb0d8c77 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -382,7 +382,7 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i } /* - * \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence + * \brief Calculates the fourth term in the qnn.conv2d depthwise lowering sequence for non-constant zero_points. * \param input_zero_point The Expr for the input zero point. * \param kernel_zero_point The Expr for the kernel zero point. @@ -393,8 +393,8 @@ Expr DepthwiseConv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_i * * Sigma(r, s) zp_a * zp_w */ -Expr DepthwiseConv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, int kernel_h, - int kernel_w) { +Expr DepthwiseConv2DFourthTerm(const Expr& input_zero_point, const Expr& kernel_zero_point, + int kernel_h, int kernel_w) { Expr scalar_term4 = MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w); Expr variable_term4 = Multiply(input_zero_point, kernel_zero_point); return Multiply(scalar_term4, variable_term4); @@ -559,7 +559,7 @@ Expr Conv2DFourthTerm(int input_zero_point_int, int kernel_zero_point_int, int i } /* - * \brief Calculates the fourth term in the qnn.conv2d lowering sequence + * \brief Calculates the fourth term in the qnn.conv2d lowering sequence for non-constant zero_points. * \param input_zero_point The Expr for the input zero point. * \param kernel_zero_point The Expr for the kernel zero point. @@ -719,7 +719,9 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, // Figure out the channel axis to force appropriate shape for kernel. 
Layout layout(param->data_layout); int channel_axis = layout.IndexOf(LayoutAxis::Get('C')); - kernel_zero_point = Reshape(kernel_zero_point, {-1,}); + kernel_zero_point = Reshape(kernel_zero_point, { + -1, + }); kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {channel_axis}); dynamic_zp = true; } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 8d9522c38fa9..973a93ef95b5 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -615,7 +615,10 @@ def test_dynamic_gather(): inputs=[], outputs=["indices"], value=onnx.helper.make_tensor( - name="const_indices", data_type=onnx.TensorProto.INT64, dims=[], vals=[1], + name="const_indices", + data_type=onnx.TensorProto.INT64, + dims=[], + vals=[1], ), ) y = helper.make_node("Gather", ["in", "indices"], ["out"], axis=axis) @@ -888,7 +891,12 @@ def test_slice(): x, x[:, 1::2], starts=(1,), ends=(9223372036854775807,), axes=(1,), steps=(2,) ) _test_slice_iteration_v10( - x, x[0::1, 1::2], starts=(0, 1), ends=(4, 4), axes=(0, 1), steps=(1, 2), + x, + x[0::1, 1::2], + starts=(0, 1), + ends=(4, 4), + axes=(0, 1), + steps=(1, 2), ) @@ -1271,7 +1279,10 @@ def verify_instance_norm(shape, axis=1): epsilon = 1e-5 node = onnx.helper.make_node( - "InstanceNormalization", inputs=["x", "gamma", "beta"], outputs=["y"], epsilon=epsilon, + "InstanceNormalization", + inputs=["x", "gamma", "beta"], + outputs=["y"], + epsilon=epsilon, ) graph = helper.make_graph( [node], @@ -1634,7 +1645,13 @@ def verify_pad(indata, pads, mode="constant", value=0.0): # onnx graph if mode in ["edge", "reflect"]: outdata = np.pad(indata, pad_width=np_pads, mode=mode) - node = helper.make_node("Pad", inputs=["input"], outputs=["output"], mode=mode, pads=pads,) + node = helper.make_node( + "Pad", + inputs=["input"], + outputs=["output"], + mode=mode, + pads=pads, + ) else: outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value) node = helper.make_node( @@ -1910,7 +1927,9 @@ def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5, dtype="float32"): graph = helper.make_graph( [z], "_test", - inputs=[helper.make_tensor_value_info("in1", ONNX_DTYPE, list(in_shape)),], + inputs=[ + helper.make_tensor_value_info("in1", ONNX_DTYPE, list(in_shape)), + ], outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, list(out_shape))], ) model = helper.make_model(graph, producer_name="_test") @@ -2093,7 +2112,11 @@ def Sign_x(x): def verify_not(indata, dtype): x = indata.astype(dtype) - node = helper.make_node("Not", inputs=["in"], outputs=["out"],) + node = helper.make_node( + "Not", + inputs=["in"], + outputs=["out"], + ) graph = helper.make_graph( [node], @@ -2121,7 +2144,11 @@ def verify_and(indata, dtype): y = indata[1].astype(dtype) outdata = np.logical_and(x, y) - node = helper.make_node("And", inputs=["in1", "in2"], outputs=["out"],) + node = helper.make_node( + "And", + inputs=["in1", "in2"], + outputs=["out"], + ) graph = helper.make_graph( [node], @@ -2273,7 +2300,11 @@ def verify_or(indata, dtype): y = indata[1].astype(dtype) outdata = np.logical_or(x, y) - node = helper.make_node("Or", inputs=["in1", "in2"], outputs=["out"],) + node = helper.make_node( + "Or", + inputs=["in1", "in2"], + outputs=["out"], + ) graph = helper.make_graph( [node], @@ -3618,7 +3649,14 @@ def verify_opset_10(ishape, scales, mode): make_constant_node("scales", onnx.TensorProto.FLOAT, (len(scales),), scales), ] input_names = ["X", "scales"] - 
nodes.append(helper.make_node("Resize", inputs=input_names, outputs=["Y"], mode=mode,)) + nodes.append( + helper.make_node( + "Resize", + inputs=input_names, + outputs=["Y"], + mode=mode, + ) + ) oshape = [round(dim * scale) for (dim, scale) in zip(ishape, scales)] graph = helper.make_graph( @@ -3638,7 +3676,11 @@ def verify_opset_10(ishape, scales, mode): @tvm.testing.uses_gpu def test_nonzero(): def verify_nonzero(indata, outdata, dtype): - node = helper.make_node("NonZero", inputs=["X"], outputs=["Y"],) + node = helper.make_node( + "NonZero", + inputs=["X"], + outputs=["Y"], + ) graph = helper.make_graph( [node], @@ -3675,7 +3717,13 @@ def verify_topk(input_dims, K, axis=-1): "topk_test", inputs=[ helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), - helper.make_tensor_value_info("K", TensorProto.INT64, [1,],), + helper.make_tensor_value_info( + "K", + TensorProto.INT64, + [ + 1, + ], + ), ], outputs=[ helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), @@ -3728,7 +3776,13 @@ def verify_roi_align( inputs=[ helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), helper.make_tensor_value_info("rois", TensorProto.FLOAT, [num_roi, 4]), - helper.make_tensor_value_info("batch_indicies", TensorProto.INT64, [num_roi,],), + helper.make_tensor_value_info( + "batch_indicies", + TensorProto.INT64, + [ + num_roi, + ], + ), ], outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_dims)], ) @@ -3781,7 +3835,10 @@ def verify_nms( ) inputs.append(score_threshold) node = helper.make_node( - "NonMaxSuppression", inputs=input_names, outputs=["Y"], center_point_box=0, + "NonMaxSuppression", + inputs=input_names, + outputs=["Y"], + center_point_box=0, ) graph = helper.make_graph( @@ -4095,7 +4152,9 @@ def append_constant_nodes(nodes, outputs, expected, name): if_graph = onnx.helper.make_graph( [if_node], "if_outer", - inputs=[onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []),], + inputs=[ + onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []), + ], outputs=graph_outputs, ) @@ -4128,7 +4187,11 @@ def test_if(): @tvm.testing.uses_gpu def test_size(): def verify_size(indata): - node = helper.make_node("Size", inputs=["X"], outputs=["Y"],) + node = helper.make_node( + "Size", + inputs=["X"], + outputs=["Y"], + ) graph = helper.make_graph( [node], @@ -4219,7 +4282,11 @@ def verify_maxunpool(data, indices, kernel_shape, strides, output_shape=None, pa @tvm.testing.uses_gpu def test_softplus(): def verify_softplus(indata): - node = helper.make_node("Softplus", inputs=["X"], outputs=["Y"],) + node = helper.make_node( + "Softplus", + inputs=["X"], + outputs=["Y"], + ) graph = helper.make_graph( [node], @@ -4243,7 +4310,11 @@ def verify_softplus(indata): @tvm.testing.uses_gpu def test_cumsum(): def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): - cumsum_node = onnx.helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"],) + cumsum_node = onnx.helper.make_node( + "CumSum", + inputs=["X", "axis"], + outputs=["Y"], + ) if exclusive != 0: exclusive_attr = helper.make_attribute("exclusive", exclusive) cumsum_node.attribute.append(exclusive_attr) @@ -4263,7 +4334,9 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): graph = helper.make_graph( nodes, "cumsum_test", - inputs=[helper.make_tensor_value_info("X", tensor_type, list(indata.shape)),], + inputs=[ + helper.make_tensor_value_info("X", tensor_type, list(indata.shape)), + ], 
outputs=[helper.make_tensor_value_info("Y", tensor_type, list(indata.shape))], ) @@ -4272,7 +4345,22 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): verify_with_ort_with_inputs(model, [indata], dtype=type, use_vm=True, opset=11) data = ( - np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,]) + np.array( + [ + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 10.0, + 11.0, + 12.0, + ] + ) .astype(np.float32) .reshape((3, 4)) ) @@ -4299,7 +4387,11 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): @tvm.testing.uses_gpu def test_eyelike(): def verify_eyelike(indata): - node = helper.make_node("EyeLike", inputs=["X"], outputs=["Y"],) + node = helper.make_node( + "EyeLike", + inputs=["X"], + outputs=["Y"], + ) graph = helper.make_graph( [node], @@ -4659,7 +4751,11 @@ def test_onnx_nodes(test, target): def test_wrong_input(): - node = helper.make_node("Softplus", inputs=["X"], outputs=["Y"],) + node = helper.make_node( + "Softplus", + inputs=["X"], + outputs=["Y"], + ) graph = helper.make_graph( [node], @@ -4740,7 +4836,10 @@ def verify_reverse_sequence(x, sequence_lens, batch_axis, time_axis): @tvm.testing.uses_gpu def test_reverse_sequence(): - x = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]], dtype=np.float32,) + x = np.array( + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]], + dtype=np.float32, + ) sequence_lens = np.array([1, 2, 3, 4], dtype=np.int64) verify_reverse_sequence(x, sequence_lens, 0, 1) diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 9cb09d84b485..ffd1a29057c8 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -441,8 +441,10 @@ def test_dynamic_zero_point(): data_dtype = "uint8" kernel_shape = (3, 4, 2, 2) kernel_dtype = "uint8" - input_zero_point = relay.op.multiply(relay.const(2, dtype='int32'), relay.const(2, dtype='int32')) - kernel_zero_point = relay.const(np.random.randint(10, size=[3]), 'int32') + input_zero_point = relay.op.multiply( + relay.const(2, dtype="int32"), relay.const(2, dtype="int32") + ) + kernel_zero_point = relay.const(np.random.randint(10, size=[3]), "int32") ref_func, qnn_func = get_funcs( data_shape=data_shape, data_dtype=data_dtype, From 1a8ea0f1ad61baaf29b8d714e0a89e33edd81dc5 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 2 Aug 2021 17:41:30 +0000 Subject: [PATCH 06/10] Fixed fallback broadcasting. --- src/relay/qnn/op/convolution.cc | 24 +++++++++---- tests/python/relay/test_op_qnn_conv2d.py | 43 +++++++++++++----------- 2 files changed, 41 insertions(+), 26 deletions(-) diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 0b21eb0d8c77..2d88ebe2456e 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -716,13 +716,8 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, input_zero_point_int = GetScalarFromConstant(input_zero_point); kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); } else { - // Figure out the channel axis to force appropriate shape for kernel. - Layout layout(param->data_layout); - int channel_axis = layout.IndexOf(LayoutAxis::Get('C')); - kernel_zero_point = Reshape(kernel_zero_point, { - -1, - }); - kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {channel_axis}); + // Make kernel_zero_point expression a 1-D tensor for consistent shape. 
+ kernel_zero_point = Reshape(kernel_zero_point, {-1,}); dynamic_zp = true; } @@ -738,6 +733,21 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, bool supported_dilation = (kernel_zero_point_int == 0) || (dilation_h == 1 && dilation_w == 1); bool supported_groups = (param->groups == 1 || is_depthwise(param)); bool conv2d_params_supported = supported_dilation && supported_groups; + + // If we need to fall back to default conv2d, kernel zp may need to be broadcast to kernel_layout. + // Otherwise, we broadcast it to data_layout for qnn lowering. + if (dynamic_zp) { + if (!conv2d_params_supported) { + Layout kernel_layout(param->kernel_layout); + int kernel_axis = kernel_layout.IndexOf(LayoutAxis::Get("O")); + kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {kernel_axis}); + } else { + Layout data_layout(param->data_layout); + int channel_axis = data_layout.IndexOf(LayoutAxis::Get("C")); + kernel_zero_point = ExpandBiasToMatchAxis(kernel_zero_point, 4, {channel_axis}); + } + } + if (!conv2d_params_supported) { return Conv2DFallBack(data, weight, input_zero_point, kernel_zero_point, param); } else if (is_depthwise(param)) { diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index ffd1a29057c8..2f6e17250acc 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -958,13 +958,17 @@ def test_depthwise_depth_multiplier(): data_dtype = "uint8" kernel_shape = (4, 1, 3, 3) kernel_dtype = "uint8" + input_zero_point = relay.op.multiply( + relay.const(2, dtype="int32"), relay.const(2, dtype="int32") + ) + kernel_zero_point = relay.const(np.random.randint(10, size=[4]), "int32") ref_func, qnn_func = get_funcs( data_shape=data_shape, data_dtype=data_dtype, kernel_shape=kernel_shape, kernel_dtype=kernel_dtype, - input_zero_point=5, - kernel_zero_point=3, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, input_scale=1.0, kernel_scale=1.0, kernel_size=(3, 3), @@ -975,6 +979,7 @@ def test_depthwise_depth_multiplier(): kernel_layout="OIHW", out_dtype="int32", groups=4, + channels=4, ) verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, kernel_dtype) @@ -989,8 +994,8 @@ def test_depthwise_depth_multiplier(): data_dtype=data_dtype, kernel_shape=kernel_shape, kernel_dtype=kernel_dtype, - input_zero_point=5, - kernel_zero_point=3, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, input_scale=1.0, kernel_scale=1.0, kernel_size=(3, 3), @@ -1090,19 +1095,19 @@ def test_per_channel_kernel_scale(): if __name__ == "__main__": - test_no_zero_point() - test_input_zero_point() - test_kernel_zero_point() - test_both_zero_point() - test_layout() - test_padding() - test_dilation() - test_const_folding() - test_kernel_size_1x1() - test_kernel_size_1x1_strides_2() - test_tflite_large_irregular() - test_broadcast_layout() - test_tflite_output_multiplier_greater_than_one() - test_tflite_anistropic_strides() + #test_no_zero_point() + #test_input_zero_point() + #test_kernel_zero_point() + #test_both_zero_point() + #test_layout() + #test_padding() + #test_dilation() + #test_const_folding() + #test_kernel_size_1x1() + #test_kernel_size_1x1_strides_2() + #test_tflite_large_irregular() + #test_broadcast_layout() + #test_tflite_output_multiplier_greater_than_one() + #test_tflite_anistropic_strides() test_depthwise_depth_multiplier() - test_per_channel_kernel_scale() + #test_per_channel_kernel_scale() From 
5a89d719c4633307a511246efc9b9db08d27702b Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 2 Aug 2021 18:24:52 +0000 Subject: [PATCH 07/10] Fix fallback broadcasting. --- tests/python/relay/test_op_qnn_conv2d.py | 42 +++++++++++++----------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 2f6e17250acc..3736350cbfe1 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -54,9 +54,11 @@ def get_ref_func( if isinstance(kernel_zero_point, (int, float)): kernel_zero_point = relay.const(kernel_zero_point, "int32") else: - # Kernel zero point expression requires manual broadcasting for OIHW. + # Kernel zero point expression requires manual broadcasting for some layouts. if kernel_layout == "OIHW": kernel_zero_point = relay.reshape(kernel_zero_point, [-1, 1, 1, 1]) + elif kernel_layout == "HWOI": + kernel_zero_point = relay.reshape(kernel_zero_point, [1, 1, -1, 1]) casted_data = relay.op.cast(data, "int32") casted_kernel = relay.op.cast(kernel, "int32") @@ -1021,8 +1023,8 @@ def test_depthwise_depth_multiplier(): data_dtype=data_dtype, kernel_shape=kernel_shape, kernel_dtype=kernel_dtype, - input_zero_point=5, - kernel_zero_point=3, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, input_scale=1.0, kernel_scale=1.0, kernel_size=(3, 3), @@ -1046,8 +1048,8 @@ def test_depthwise_depth_multiplier(): data_dtype=data_dtype, kernel_shape=kernel_shape, kernel_dtype=kernel_dtype, - input_zero_point=5, - kernel_zero_point=3, + input_zero_point=input_zero_point, + kernel_zero_point=kernel_zero_point, input_scale=1.0, kernel_scale=1.0, kernel_size=(3, 3), @@ -1095,19 +1097,19 @@ def test_per_channel_kernel_scale(): if __name__ == "__main__": - #test_no_zero_point() - #test_input_zero_point() - #test_kernel_zero_point() - #test_both_zero_point() - #test_layout() - #test_padding() - #test_dilation() - #test_const_folding() - #test_kernel_size_1x1() - #test_kernel_size_1x1_strides_2() - #test_tflite_large_irregular() - #test_broadcast_layout() - #test_tflite_output_multiplier_greater_than_one() - #test_tflite_anistropic_strides() + test_no_zero_point() + test_input_zero_point() + test_kernel_zero_point() + test_both_zero_point() + test_layout() + test_padding() + test_dilation() + test_const_folding() + test_kernel_size_1x1() + test_kernel_size_1x1_strides_2() + test_tflite_large_irregular() + test_broadcast_layout() + test_tflite_output_multiplier_greater_than_one() + test_tflite_anistropic_strides() test_depthwise_depth_multiplier() - #test_per_channel_kernel_scale() + test_per_channel_kernel_scale() From a4d1b62b6b37a7599c42040112832663197c75dd Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 2 Aug 2021 18:25:21 +0000 Subject: [PATCH 08/10] Formatting. --- src/relay/qnn/op/convolution.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 2d88ebe2456e..cf5266485f2e 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -717,7 +717,9 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); } else { // Make kernel_zero_point expression a 1-D tensor for consistent shape. 
- kernel_zero_point = Reshape(kernel_zero_point, {-1,}); + kernel_zero_point = Reshape(kernel_zero_point, { + -1, + }); dynamic_zp = true; } From ada06e0194c00749d0be863d6af1ce9aaf425e97 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 2 Aug 2021 19:08:14 +0000 Subject: [PATCH 09/10] Fix lint --- python/tvm/relay/qnn/op/legalizations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 76da216e5534..3226240fbe39 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -210,9 +210,9 @@ def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op): # Helper function to change dtypes to int8 x int8. Cuda dp4a instructions prefer this setting. def helper_change_dtypes_to_int8(attrs, inputs, types, relay_op): - """Legalizes QNN conv2d/dense op for Nvidia HW. dp4a supports i8 x i8 fast conv/MM. If the dtypes - are already good, we dont transform. Else, we shift the tensor values and zero points to change - the dtype. + """Legalizes QNN conv2d/dense op for Nvidia HW. dp4a supports i8 x i8 fast conv/MM. If the + dtypes are already good, we dont transform. Else, we shift the tensor values and zero points + to change the dtype. Parameters ---------- From 8f7e04820bd64123c9f016e73956c4e0ed7ffa69 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 6 Aug 2021 16:40:19 +0000 Subject: [PATCH 10/10] Merge with new test parameterization. --- tests/python/frontend/onnx/test_forward.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 5eb861726e59..8422cda42afc 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4919,8 +4919,6 @@ def verify_eyelike(indata): target_skips = { "cuda": [ - "test_basic_convinteger", - "test_convinteger_with_padding", "test_range_float_type_positive_delta_expanded", "test_range_int32_type_positive_delta_expanded", "test_mod_mixed_sign_float16", @@ -5443,8 +5441,7 @@ def verify_convinteger( ) model = helper.make_model(graph, producer_name="convinteger_test") # opt_level=1 will cause error - verify_with_ort_with_inputs(model, input_values, opt_level=2) - + verify_with_ort_with_inputs(model, input_values, target=target, dev=dev, opt_level=2) def repeat(N, D): return tuple([N for _ in range(D)])
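
With the full series applied, the ONNX ConvInteger tests run on CUDA: test_basic_convinteger and test_convinteger_with_padding are removed from the cuda entry of target_skips, and verify_convinteger forwards target/dev to the ORT comparison instead of relying on defaults. The graph the tests now build, condensed into a sketch with smaller assumed shapes, looks like this; the zero points live in the initializer list rather than the graph inputs, so the importer sees them as constants:

    import numpy as np
    from onnx import TensorProto, helper

    node = helper.make_node(
        "ConvInteger", ["x", "w", "x_zero_point", "w_zero_point"], ["y"]
    )
    graph = helper.make_graph(
        [node],
        "convinteger_sketch",
        inputs=[
            helper.make_tensor_value_info("x", TensorProto.UINT8, [1, 2, 4, 4]),
            helper.make_tensor_value_info("w", TensorProto.UINT8, [2, 2, 2, 2]),
        ],
        outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, [1, 2, 3, 3])],
        # Scalar zero points as initializers, mirroring the updated test.
        initializer=[
            helper.make_tensor("x_zero_point", TensorProto.UINT8, [], np.array([3], dtype="uint8")),
            helper.make_tensor("w_zero_point", TensorProto.UINT8, [], np.array([5], dtype="uint8")),
        ],
    )
    model = helper.make_model(graph, producer_name="convinteger_sketch")

On import, those constant zero points flow into qnn.conv2d, the CUDA legalization shifts the uint8 operands to int8, and the canonicalization from the earlier patches emits the four-term lowering.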