diff --git a/CMakeLists.txt b/CMakeLists.txt index 18f58c8ccb9c..830d4d2e0014 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -95,6 +95,7 @@ if(MSVC) add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj") if(USE_MSVC_MT) foreach(flag_var diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 4b5cd89f0b0c..1d840f8972e7 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -32,6 +32,25 @@ namespace tvm { namespace relay { namespace qnn { +/*! \brief Attribute for qnn add operator */ +struct QnnAddAttrs : public tvm::AttrsNode<QnnAddAttrs> { + std::string rounding; + + TVM_DECLARE_ATTRS(QnnAddAttrs, "relay.attrs.QnnAddAttrs") { + TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe( + "Defines the rounding direction when the value is midway between" + "two representable values. There are three modes - UPWARD, TONEAREST" + "or TFLITE. UPWARD/TONEAREST modes behave exactly the same except at the" + "midpoints between the two representable values. At the midpoint," + "UPWARD rounds towards positive infinity (for example -1.5 will be" + "rounded to -1). TONEAREST is the standard rounding where the" + "value is rounded away from zero at midpoints (for example, -1.5" + "rounds to -2). More context can be found at following glibc manual" + "https://www.gnu.org/software/libc/manual/html_node/Rounding.html." + "TFLITE mode is more complicated, referring to tflite implementation."); + } +}; + /*! \brief Attribute for requantize operator */ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> { int axis; @@ -46,14 +65,15 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> { .set_default(-1); TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe( "Defines the rounding direction when the value is midway between" - "two representable values. There are two supported modes - UPWARD" - "or TONEAREST. Both modes behave exactly same except at the" + "two representable values. There are three modes - UPWARD, TONEAREST" + "or TFLITE. UPWARD/TONEAREST modes behave exactly the same except at the" "midpoints between the two representable values. At the midpoint," "UPWARD rounds towards positive infinity (for example -1.5 will be" "rounded to -1). TONEAREST is the standard rounding where the" "value is rounded away from zero at midpoints (for example, -1.5" - "rounds to -2). More context can be found at following gblic manual" - "https://www.gnu.org/software/libc/manual/html_node/Rounding.html."); + "rounds to -2). More context can be found at following glibc manual" + "https://www.gnu.org/software/libc/manual/html_node/Rounding.html."
+ "TFLITE mode is more complicated, referring to tflite implementation."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 5a645c67cf61..2efa52e0ba8f 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -45,7 +45,7 @@ def __init__(self, tensor_idx, tensor, buffer, qnn_params=None): class OperatorConverter(object): """Operator Converted for converting TFLite ops to Relay ops""" - def __init__(self, model, subgraph, exp_tab): + def __init__(self, model, subgraph, exp_tab, rounding): try: from tflite.BuiltinOperator import BuiltinOperator @@ -60,6 +60,7 @@ def __init__(self, model, subgraph, exp_tab): self.builtin_op_code = build_str_map(BuiltinOperator()) self.activation_fn_type = build_str_map(ActivationFunctionType()) self.builtin_options = build_str_map(BuiltinOptions()) + self.rounding = rounding # Add more operators self.convert_map = { @@ -643,6 +644,9 @@ def _hard_swish(data): return out + # TODO in quantized mode, concat op implicitly invokes requantize in cpp + # implementation, TFLITE mode rounding needed to be selected in order + # to get bit-exact execution. def convert_concatenation(self, op): """Convert TFLite concatenation""" try: @@ -854,14 +858,26 @@ def _convert_elemwise(self, relay_op, op): if lhs_tensor.qnn_params: assert rhs_tensor.qnn_params, "Both tensors should be quantized." assert output_tensor.qnn_params, "Output tensor should be quantized." - out = relay_op(lhs=lhs_expr, - rhs=rhs_expr, - lhs_scale=lhs_tensor.qnn_params['scale'], - lhs_zero_point=lhs_tensor.qnn_params['zero_point'], - rhs_scale=rhs_tensor.qnn_params['scale'], - rhs_zero_point=rhs_tensor.qnn_params['zero_point'], - output_scale=output_tensor.qnn_params['scale'], - output_zero_point=output_tensor.qnn_params['zero_point']) + has_tflite_rounding_mode = [_qnn.op.add] + if relay_op in has_tflite_rounding_mode: + out = relay_op(lhs=lhs_expr, + rhs=rhs_expr, + lhs_scale=lhs_tensor.qnn_params['scale'], + lhs_zero_point=lhs_tensor.qnn_params['zero_point'], + rhs_scale=rhs_tensor.qnn_params['scale'], + rhs_zero_point=rhs_tensor.qnn_params['zero_point'], + output_scale=output_tensor.qnn_params['scale'], + output_zero_point=output_tensor.qnn_params['zero_point'], + rounding=self.rounding) + else: + out = relay_op(lhs=lhs_expr, + rhs=rhs_expr, + lhs_scale=lhs_tensor.qnn_params['scale'], + lhs_zero_point=lhs_tensor.qnn_params['zero_point'], + rhs_scale=rhs_tensor.qnn_params['scale'], + rhs_zero_point=rhs_tensor.qnn_params['zero_point'], + output_scale=output_tensor.qnn_params['scale'], + output_zero_point=output_tensor.qnn_params['zero_point']) else: out = relay_op(lhs_expr, rhs_expr) @@ -924,6 +940,9 @@ def convert_sub(self, op): return self._convert_elemwise(_qnn.op.subtract, op) return self._convert_elemwise(_op.subtract, op) + # TODO in quantized mode, mul op implicitly invokes requantize in cpp + # implementation, TFLITE mode rounding needed to be selected in order + # to get bit-exact execution. 
def convert_mul(self, op): """Convert TFLite MUL""" # Check if the input tensor is quantized, call QNN op @@ -1327,6 +1346,7 @@ def _convert_reduce(self, relay_op, op): input_zero_point=input_tensor.qnn_params['zero_point'], output_scale=output_tensor.qnn_params['scale'], output_zero_point=output_tensor.qnn_params['zero_point'], + rounding=self.rounding, out_dtype=output_tensor_type_str) return out @@ -1452,6 +1472,7 @@ def convert_fully_connected(self, op): input_zero_point=new_input_zero_point, output_scale=output_tensor.qnn_params['scale'], output_zero_point=output_tensor.qnn_params['zero_point'], + rounding=self.rounding, out_dtype=output_tensor_type_str) # Call activation function @@ -1667,6 +1688,7 @@ def convert_conv(self, op, conv_type): input_zero_point=new_input_zero_point, output_scale=output_tensor.qnn_params['scale'], output_zero_point=output_tensor.qnn_params['zero_point'], + rounding=self.rounding, out_dtype=output_tensor_type_str) # Call activation function @@ -2454,8 +2476,10 @@ def get_scalar_from_constant(expr): assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ "Expr is not a constant scalar." value = expr.data.asnumpy() - assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ - "value must be float32/int32" + assert value.dtype == np.dtype(np.int32) or \ + value.dtype == np.dtype(np.float32) or \ + value.dtype == np.dtype(np.float64), \ + "value must be float32/float64/int32" return np.asscalar(value) @@ -2524,7 +2548,7 @@ def get_tensor_name(subgraph, tensor_idx): return subgraph.Tensors(tensor_idx).Name().decode("utf-8") -def from_tflite(model, shape_dict, dtype_dict): +def from_tflite(model, shape_dict, dtype_dict, rounding='TFLITE'): """Convert from tflite model into compatible relay Function. Parameters @@ -2538,6 +2562,9 @@ def from_tflite(model, shape_dict, dtype_dict): dtype_dict : dict of str to str Input types of the model. + rounding : str + Rounding mode for tflite model + Returns ------- mod : tvm.IRModule @@ -2576,7 +2603,7 @@ def from_tflite(model, shape_dict, dtype_dict): exp_tab.set_expr(model_input_name, _expr.var(model_input_name, shape=shape, dtype=dtype)) # op code in model - op_converter = OperatorConverter(model, subgraph, exp_tab) + op_converter = OperatorConverter(model, subgraph, exp_tab, rounding) op_converter.check_unsupported_ops() op_converter.convert_op_to_relay() diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 5a3106d1e787..f6cdc6afb1be 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -300,7 +300,8 @@ def add(lhs, rhs_scale, rhs_zero_point, output_scale, - output_zero_point): + output_zero_point, + rounding="UPWARD"): """Quantized addition with numpy-style broadcasting. Parameters @@ -329,6 +330,9 @@ def add(lhs, output_zero_point: relay.Expr The zero point of output quantized expr. + rounding: str, optional + rounding mode of qnn add + Returns ------- result : relay.Expr @@ -338,7 +342,8 @@ def add(lhs, return _make.add(lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, - output_scale, output_zero_point) + output_scale, output_zero_point, + rounding) def dense(data, diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index e5d54099b07c..4d107b56b1ee 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -172,7 +172,7 @@ def qconfig(**kwargs): is None, which means will try to call all operartors' annotate rewrite function. 
- rounding: "UPWARD" or "TONEAREST" + rounding: "UPWARD" or "TONEAREST" or "TFLITE" Rounding direction for fixed point multiplications. Returns diff --git a/src/relay/qnn/op/add.cc b/src/relay/qnn/op/add.cc index b0dc3e4af5c4..c92b11dfc4cf 100644 --- a/src/relay/qnn/op/add.cc +++ b/src/relay/qnn/op/add.cc @@ -30,9 +30,11 @@ namespace tvm { namespace relay { namespace qnn { +TVM_REGISTER_NODE_TYPE(QnnAddAttrs); + /* * \brief Canonicalizes the QNN add op. - * \param attrs The empty attribute. + * \param attrs The QNN add attrs. * \param new_args The new mutated args to the call node. * \param arg_types The types of input and output. * \return The sequence of Relay ops for add op. @@ -42,9 +44,55 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, // Get the args. QnnBinaryOpArguments args(new_args); + // Get the attrs. + const QnnAddAttrs* add_attrs = attrs.as(); + CHECK(add_attrs != nullptr); + auto& rounding = add_attrs->rounding; + // Get the input dtype and shape. QnnBinaryOpTensorType input_type(arg_types, 0); + if (rounding == "TFLITE") { + double lhs_scale_val = GetScalarFromConstant(args.lhs_scale); + double rhs_scale_val = GetScalarFromConstant(args.rhs_scale); + double out_scale_val = GetScalarFromConstant(args.output_scale); + double twice_max_input_scale = 2 * std::max(lhs_scale_val, rhs_scale_val); + double real_lhs_scale_val = lhs_scale_val / twice_max_input_scale; + double real_rhs_scale_val = rhs_scale_val / twice_max_input_scale; + double real_out_scale_val = twice_max_input_scale / ((1 << 20) * out_scale_val); + + auto real_lhs_scale = MakeConstantScalar(DataType::Float(64), real_lhs_scale_val); + auto real_rhs_scale = MakeConstantScalar(DataType::Float(64), real_rhs_scale_val); + auto real_out_scale = MakeConstantScalar(DataType::Float(64), real_out_scale_val); + auto one_scalar = MakeConstantScalar(DataType::Float(64), 1); + auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); + auto left_shift_scalar = MakeConstantScalar(DataType::Int(32), 1 << 20); + + Expr adapted_lhs = Cast(args.lhs, DataType::Int(32)); + if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) { + adapted_lhs = Subtract(adapted_lhs, Cast(args.lhs_zero_point, DataType::Int(32))); + } + adapted_lhs = Multiply(adapted_lhs, left_shift_scalar); + + Expr adapted_rhs = Cast(args.rhs, DataType::Int(32)); + if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) { + adapted_rhs = Subtract(adapted_rhs, Cast(args.rhs_zero_point, DataType::Int(32))); + } + adapted_rhs = Multiply(adapted_rhs, left_shift_scalar); + + auto requantized_lhs = Requantize(adapted_lhs, input_type.shape, real_lhs_scale, zero_scalar, + one_scalar, zero_scalar, DataType::Int(32), rounding); + + auto requantized_rhs = Requantize(adapted_rhs, input_type.shape, real_rhs_scale, zero_scalar, + one_scalar, zero_scalar, DataType::Int(32), rounding); + + auto output = Add(requantized_lhs, requantized_rhs); + output = Requantize(output, input_type.shape, real_out_scale, zero_scalar, one_scalar, + args.output_zero_point, DataType::Int(32), rounding); + // Go back to lower precision. + return ConvertDtype(output, input_type.dtype); + } + // FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in // the start, we can insert requantize at the end if both input tensors have same qnn params. In // that case, we can first add the tensors, subtract the zero point, and requantize at the end. 
@@ -86,9 +134,23 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, return ConvertDtype(output, input_type.dtype); } +Expr MakeQnnAdd(Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, + Expr rhs_zero_point, Expr output_scale, Expr output_zero_point, + std::string rounding) { + auto attrs = make_object(); + attrs->rounding = std::move(rounding); + + static const Op& op = Op::Get("qnn.add"); + return Call(op, + {lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, + output_zero_point}, + Attrs(attrs), {}); +} + // QNN Addition operator. -QNN_REGISTER_BINARY_OP("add") +QNN_REGISTER_BINARY_OP_WITH_BODY("add", MakeQnnAdd) .describe("Elementwise add with with broadcasting for quantized tensors.") + .set_attrs_type() .set_support_level(11) .set_attr("FTVMQnnCanonicalize", QnnAddCanonicalize); diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index bda8cf878793..a351147cb0cc 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -48,7 +48,10 @@ bool QnnConcatenateRel(const Array& types, int num_inputs, const Attrs& at << PrettyPrint(types[1])); } for (const auto& input_scale : input_scales_tuple->fields) { - CHECK(IsScalarType(input_scale, DataType::Float(32))); // input_scales[idx] + const auto* input_scale_type = input_scale.as(); + CHECK(input_scale_type && IsScalarType(input_scale, input_scale_type->dtype) && + (input_scale_type->dtype == DataType::Float(32) || + input_scale_type->dtype == DataType::Float(64))); // input_scale[idx] } const auto* input_zero_points_tuple = types[2].as(); @@ -61,8 +64,12 @@ bool QnnConcatenateRel(const Array& types, int num_inputs, const Attrs& at CHECK(IsScalarType(input_zero_point, DataType::Int(32))); // input_zero_points[idx] } - CHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale - CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point + const auto* output_scale_type = types[3].as(); + CHECK(output_scale_type && IsScalarType(types[3], output_scale_type->dtype) && + (output_scale_type->dtype == DataType::Float(32) || + output_scale_type->dtype == DataType::Float(64))); // output_scale + + CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // Concatenate infer type function. diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index ae52a42e42b8..29d5d6b0d217 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -57,13 +57,18 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; // Check the types of scale and zero points. - CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point - CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point - CHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale + CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point + CHECK(IsScalarType(types[3], DataType::Int(32))); // kernel_zero_point + const auto* input_scale_type = types[4].as(); + CHECK(input_scale_type && IsScalarType(types[4], input_scale_type->dtype) && + (input_scale_type->dtype == DataType::Float(32) || + input_scale_type->dtype == DataType::Float(64))); // input_scale + // Kernel scale can be a vector of length output_channels or a scalar. 
size_t axis = param->kernel_layout.find('O'); CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined"; - AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter); // kernel scale + const auto* kernel_scale_type = types[5].as(); + AssignType(types[5], kernel_scale_type->dtype, weight->shape[axis], reporter); // kernel scale // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // Conv2D infer type function. diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 7c014d71a76a..5e00d835fbb2 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -46,8 +46,10 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, << input_dtype; // Check the types of scale and zero points. - CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale - CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point + const auto* input_scale_type = types[1].as(); + CHECK(input_scale_type && (input_scale_type->dtype == DataType::Float(32) || + input_scale_type->dtype == DataType::Float(64))); + CHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point const Array oshape = data->shape; // assign output type, output will always be float 32. @@ -66,7 +68,8 @@ Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point) { Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& input_zero_point) { auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), input_zero_point); - auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), input_scale); + auto scaled_output = + Multiply(Cast(shift, DataType::Float(32)), Cast(input_scale, DataType::Float(32))); return scaled_output; } diff --git a/src/relay/qnn/op/mul.cc b/src/relay/qnn/op/mul.cc index ec74b799407b..c947e9b723ee 100644 --- a/src/relay/qnn/op/mul.cc +++ b/src/relay/qnn/op/mul.cc @@ -49,7 +49,7 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array& new_args, QnnBinaryOpTensorType input_type(arg_types, 0); // data types const auto int32_dtype = DataType::Int(32); - const auto float32_dtype = DataType::Float(32); + const auto float64_dtype = DataType::Float(64); /* A tensor multiplication c = a * b can be written in terms of respective @@ -79,10 +79,24 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array& new_args, auto output = Multiply(lhs_shifted, rhs_shifted); // Get the adjusted new scale and zero points. 
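The dequantize change above casts the scale to float32 before multiplying, while qnn.mul below keeps the product of the two input scales in float64. A small, made-up illustration of why the wider type matters (the values are arbitrary; the point is that rounding the product to float32 usually loses low-order bits):

import numpy as np

lhs_s = np.float32(0.0039215689)   # arbitrary example scales
rhs_s = np.float32(0.0078431377)
prod32 = np.float64(lhs_s * rhs_s)              # product rounded to float32 first
prod64 = np.float64(lhs_s) * np.float64(rhs_s)  # product kept in float64
print(prod32 - prod64)  # usually non-zero; the requantize multiplier would differ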
- float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale); - float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale); - float new_scale_float = lhs_scale_float * rhs_scale_float; - auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float); + auto lhs_scale_dtype = GetDataTypeFromConstant(args.lhs_scale); + auto rhs_scale_dtype = GetDataTypeFromConstant(args.rhs_scale); + double lhs_scale_val = -1.f, rhs_scale_val = -1.f; + if (lhs_scale_dtype == DataType::Float(32)) { + lhs_scale_val = GetScalarFromConstant<float>(args.lhs_scale); + } else if (lhs_scale_dtype == DataType::Float(64)) { + lhs_scale_val = GetScalarFromConstant<double>(args.lhs_scale); + } + + if (rhs_scale_dtype == DataType::Float(32)) { + rhs_scale_val = GetScalarFromConstant<float>(args.rhs_scale); + } else if (rhs_scale_dtype == DataType::Float(64)) { + rhs_scale_val = GetScalarFromConstant<double>(args.rhs_scale); + } + CHECK(lhs_scale_val != -1.f && rhs_scale_val != -1.f); + + double new_scale_val = lhs_scale_val * rhs_scale_val; + auto new_input_scale = MakeConstantScalar(float64_dtype, new_scale_val); auto new_input_zero_point = zero_scalar; // Requantize to get Q_c diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index 50fc0cda30cf..04e06aa839cb 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -171,12 +171,21 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con CHECK_EQ(types.size(), kNumQnnBinaryOpArgTypes); // Check the scale and zero point types - CHECK(IsScalarType(types[2], DataType::Float(32))); // lhs_scale - CHECK(IsScalarType(types[3], DataType::Int(32))); // lhs_zero_point - CHECK(IsScalarType(types[4], DataType::Float(32))); // rhs_scale - CHECK(IsScalarType(types[5], DataType::Int(32))); // rhs_zero_point - CHECK(IsScalarType(types[6], DataType::Float(32))); // output_scale - CHECK(IsScalarType(types[7], DataType::Int(32))); // output_zero_point + const auto* lhs_scale_type = types[2].as<TensorTypeNode>(); + const auto* rhs_scale_type = types[4].as<TensorTypeNode>(); + const auto* out_scale_type = types[6].as<TensorTypeNode>(); + CHECK(lhs_scale_type && IsScalarType(types[2], lhs_scale_type->dtype) && + (lhs_scale_type->dtype == DataType::Float(32) || + lhs_scale_type->dtype == DataType::Float(64))); // lhs_scale + CHECK(rhs_scale_type && IsScalarType(types[4], rhs_scale_type->dtype) && + (rhs_scale_type->dtype == DataType::Float(32) || + rhs_scale_type->dtype == DataType::Float(64))); // rhs_scale + CHECK(out_scale_type && IsScalarType(types[6], out_scale_type->dtype) && + (out_scale_type->dtype == DataType::Float(32) || + out_scale_type->dtype == DataType::Float(64))); // output_scale + CHECK(IsScalarType(types[3], DataType::Int(32))); // lhs_zero_point + CHECK(IsScalarType(types[5], DataType::Int(32))); // rhs_zero_point + CHECK(IsScalarType(types[7], DataType::Int(32))); // output_zero_point // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay // BroadcastRel infer type function. @@ -194,29 +203,32 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con * * \param OpName the name of registry. */ -#define QNN_REGISTER_BINARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ - Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ - static const Op& op = Op::Get("qnn.
OpName); \ - return Call(op, \ - {lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, \ - output_zero_point}, \ - Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP("qnn." OpName) \ - .set_num_inputs(kNumQnnBinaryOpInputs) \ - .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ - .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \ - .add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.") \ - .add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.") \ - .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \ - .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \ - .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \ - .add_type_rel("QnnBroadcast", QnnBroadcastRel) \ +#define QNN_REGISTER_BINARY_OP_WITH_BODY(OpName, Body) \ + TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName).set_body_typed(Body); \ + RELAY_REGISTER_OP("qnn." OpName) \ + .set_num_inputs(kNumQnnBinaryOpInputs) \ + .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ + .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \ + .add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.") \ + .add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.") \ + .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \ + .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \ + .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \ + .add_type_rel("QnnBroadcast", QnnBroadcastRel) \ .set_attr("FInferCorrectLayout", QnnBinaryBroadcastLayout) +#define QNN_REGISTER_BINARY_OP(OpName) \ + auto DefaultBody = [](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ + Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ + static const Op& op = Op::Get("qnn." OpName); \ + return Call(op, \ + {lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, \ + output_zero_point}, \ + Attrs(), {}); \ + }; \ + QNN_REGISTER_BINARY_OP_WITH_BODY(OpName, DefaultBody) + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 28f0b8994a01..3d173e560453 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -53,8 +53,11 @@ bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; // Check and assign types for scale and zero points. 
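The quantize.cc changes that follow accept float32 or float64 scales and cast the scale to float32 before the division. A minimal NumPy reference of the overall quantize computation (a sketch with an assumed int8 output range, not the lowered Relay graph):

import numpy as np

def quantize_reference(x, scale, zero_point, qmin=-128, qmax=127):
    scaled = x / np.float32(scale)          # scale cast to float32, as in QuantizeLower
    q = np.round(scaled + zero_point)
    return np.clip(q, qmin, qmax).astype(np.int8)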
- AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale - AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point + const auto* scale_type = types[1].as(); + CHECK(scale_type && + (scale_type->dtype == DataType::Float(32) || scale_type->dtype == DataType::Float(64))); + AssignType(types[1], scale_type->dtype, data->shape[axis], reporter); // scale + AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point const Array oshape = data->shape; const DataType out_dtype = quantize_attrs->out_dtype; @@ -98,7 +101,7 @@ Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale, const int32_t min_val = GetQmin(out_dtype); const int32_t max_val = GetQmax(out_dtype); - auto scale_data = Divide(input_tensor, expanded_output_scale); + auto scale_data = Divide(input_tensor, Cast(expanded_output_scale, DataType::Float(32))); auto add_zero_point = Cast(Round(Add(scale_data, Cast(expanded_output_zero_point, DataType::Float(32)))), DataType::Int(32)); diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 79cb08d3f948..ca617f4ef551 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -133,7 +133,14 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& input_zero_point, const Expr& output_scale, const Expr& output_zero_point, const RequantizeAttrs* param, const Array& input_shape, const DataType& out_dtype) { - auto tensor = Cast(input_tensor, DataType::Int(32)); + bool input_is_int32 = false; + if ((input_tensor.as() || input_tensor.as()) && + input_tensor->checked_type_.defined()) { + auto tensor_type = input_tensor->checked_type().as(); + if (tensor_type && tensor_type->dtype == DataType::Int(32)) input_is_int32 = true; + } + auto tensor = input_is_int32 ? input_tensor : Cast(input_tensor, DataType::Int(32)); + // auto tensor = Cast(input_tensor, DataType::Int(32)); // 1) Subtract the input_zero_point auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); if (!IsEqualScalar(input_zero_point, zero_scalar)) { @@ -145,26 +152,48 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, // the whole tensor. For per-channel (aka per-axis), there is a vector of scales for the input // tensor. Depending on the quantization type, the fixed point multiplication routing is called. auto scaled_int32_t = tensor; - float output_scale_float = GetScalarFromConstant(output_scale); + auto out_scale_dtype = GetDataTypeFromConstant(output_scale); + double output_scale_val = -1.0f; + if (out_scale_dtype == DataType::Float(64)) { + output_scale_val = GetScalarFromConstant(output_scale); + } else if (out_scale_dtype == DataType::Float(32)) { + output_scale_val = GetScalarFromConstant(output_scale); + } + CHECK_GE(output_scale_val, 0.0f); + + auto in_scale_dtype = GetDataTypeFromConstant(input_scale); if (IsConstScalar(input_scale)) { // This is per-tensor quantization. Single scale. - float input_scale_float = GetScalarFromConstant(input_scale); - double double_multiplier = - static_cast(input_scale_float) / static_cast(output_scale_float); + double input_scale_val = -1.0f; + if (in_scale_dtype == DataType::Float(64)) { + input_scale_val = GetScalarFromConstant(input_scale); + } else if (in_scale_dtype == DataType::Float(32)) { + input_scale_val = GetScalarFromConstant(input_scale); + } + double double_multiplier = input_scale_val / output_scale_val; // Skip if input and output scales are same. 
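For the per-tensor case above, requantize is conceptually just a rescaling between two quantization parameter sets; the rounding mode only decides how ties are resolved once the float multiply is replaced by fixed-point arithmetic. A float reference (a sketch, before the final cast to out_dtype):

import numpy as np

def requantize_reference(q, in_scale, in_zp, out_scale, out_zp):
    real = (q.astype(np.int64) - in_zp) * (in_scale / out_scale)
    return (np.round(real) + out_zp).astype(np.int32)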
- if (!IsEqualScalar(input_scale, output_scale)) { + if (input_scale_val != output_scale_val) { scaled_int32_t = FixedPointMultiply(scaled_int32_t, double_multiplier, input_shape, param->rounding); } } else { // This is per-channel (per=axis) quantization. std::vector double_multipliers; - auto input_axis_scales = GetFloatVectorFromConstant(input_scale); - for (auto input_axis_scale : input_axis_scales) { - double multiplier = - static_cast(input_axis_scale) / static_cast(output_scale_float); - double_multipliers.push_back(multiplier); + if (in_scale_dtype == DataType::Float(64)) { + auto input_axis_scales = GetDoubleVectorFromConstant(input_scale); + for (auto input_axis_scale : input_axis_scales) { + double multiplier = input_axis_scale / output_scale_val; + double_multipliers.push_back(multiplier); + } + } else if (in_scale_dtype == DataType::Float(32)) { + auto input_axis_scales = GetFloatVectorFromConstant(input_scale); + for (auto input_axis_scale : input_axis_scales) { + double multiplier = input_axis_scale / output_scale_val; + double_multipliers.push_back(multiplier); + } } + CHECK_GT(double_multipliers.size(), 0); + int axis = param->axis; axis = (axis == -1) ? input_shape.size() - 1 : axis; scaled_int32_t = FixedPointMultiplyPerChannel(scaled_int32_t, double_multipliers, input_shape, @@ -229,9 +258,10 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, auto out_dtype = out_tensor_type->dtype; // Check rounding validity. - CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST") - << "QNN requantize supports two rounding modes - UPWARD and " - << "TONEAREST"; + CHECK(param->rounding == "UPWARD" || param->rounding == "TONEAREST" || + param->rounding == "TFLITE") + << "QNN requantize supports 3 rounding modes - UPWARD, " + << "TONEAREST and TFLITE"; return RequantizeLower(quantized_data, input_scale, input_zero_point, output_scale, output_zero_point, param, input_shape, out_dtype); } @@ -262,11 +292,17 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_GE(axis, 0) << "axis " << requantize_attrs->axis << " is out of range"; // Check and assign types for scale and zero points. - AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // input_scale - AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // input_zero_pt + const auto* input_scale_type = types[1].as(); + CHECK(input_scale_type && (input_scale_type->dtype == DataType::Float(32) || + input_scale_type->dtype == DataType::Float(64))); + AssignType(types[1], input_scale_type->dtype, data->shape[axis], reporter); // input_scale + AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // input_zero_pt // For now, requantize output tensor is limited to full tensor uniform quantization. 
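Each scale ratio computed above becomes a double multiplier that GetFixedPointMultiplierShift (in util.cc, further down) splits into a Q31 fixed-point multiplier and a power-of-two shift. A rough Python equivalent of that decomposition (an assumption about its behaviour based on the surrounding comments, not a copy of the C++):

import math

def fixed_point_multiplier_shift(multiplier):
    significand, exponent = math.frexp(multiplier)     # significand in [0.5, 1)
    fixed_point_multiplier = int(round(significand * (1 << 31)))
    if fixed_point_multiplier == (1 << 31):             # rounding pushed it to 1.0
        fixed_point_multiplier //= 2
        exponent += 1
    return fixed_point_multiplier, exponent

print(fixed_point_multiplier_shift(1 / 16))  # (1073741824, -3): multiplier 2**30, shift -3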
- CHECK(IsScalarType(types[3], DataType::Float(32))); // output_scale - CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point + const auto* output_scale_type = types[3].as(); + CHECK(output_scale_type && IsScalarType(types[3], output_scale_type->dtype) && + (output_scale_type->dtype == DataType::Float(32) || + output_scale_type->dtype == DataType::Float(64))); // output_scale + CHECK(IsScalarType(types[4], DataType::Int(32))); // output_zero_point const Array oshape = data->shape; // assign output type diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc index 7171ded765b9..946c2c0a1d9e 100644 --- a/src/relay/qnn/util.cc +++ b/src/relay/qnn/util.cc @@ -24,12 +24,52 @@ #include "util.h" +#include + #include "../transforms/pattern_util.h" namespace tvm { namespace relay { namespace qnn { +/* \brief This function implements the rounding part of ARMv7 NEON VQRDMULH + * instruction. For code reuse, the multiplied tensor is directly passed in + * as parameter. + */ +Expr SaturatingRoundingDoublingHigh32(const Expr& input_tensor, const Expr& multiplier_expr, + const Expr& scaled_tensor, + const Array& input_shape, + bool possible_to_overflow = true) { + DataType hp_dtype = DataType::Int(64); + DataType lp_dtype = DataType::Int(32); + int64_t pos_nudge_value = (1ll << 30); + int64_t neg_nudge_value = 1 - (1ll << 30); + auto pos_nudge = MakeConstantScalar(hp_dtype, pos_nudge_value); + auto neg_nudge = MakeConstantScalar(hp_dtype, neg_nudge_value); + auto pos_nudge_t = Full(pos_nudge, input_shape, hp_dtype); + auto neg_nudge_t = Full(neg_nudge, input_shape, hp_dtype); + + auto dividend = MakeConstantScalar(hp_dtype, 1ll << 31); + + auto zero_t = Zeros(input_shape, hp_dtype); + auto nudged_tensor_t = + Add(scaled_tensor, Where(GreaterEqual(scaled_tensor, zero_t), pos_nudge_t, neg_nudge_t)); + auto high32_t = Cast(Divide(nudged_tensor_t, dividend), lp_dtype); + + if (possible_to_overflow) { + auto int32_min = MakeConstantScalar(lp_dtype, std::numeric_limits::min()); + auto int32_max = MakeConstantScalar(lp_dtype, std::numeric_limits::max()); + auto int32_max_t = Full(int32_max, input_shape, lp_dtype); + auto int32_min_t = Full(int32_min, input_shape, lp_dtype); + + auto overflow_t = + LogicalAnd(Equal(input_tensor, int32_min_t), Equal(multiplier_expr, int32_min_t)); + return Where(overflow_t, int32_max_t, high32_t); + } else { + return high32_t; + } +} + /* * \brief Convert FP32 representation into fixed point representation. * \param double_multplier The input FP32 number. @@ -80,6 +120,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& // Choose high precision datatype to be int64. This is for avoiding overflow // in multiplication of two int32 values. DataType hp_dtype = DataType::Int(64); + DataType lp_dtype = DataType::Int(32); tensor = Cast(tensor, hp_dtype); // 1) Calculating the integer multiplier and integer shift @@ -100,7 +141,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& // (from the right, rightmost bit is bit 0). The computation is performed in // higher precision to avoid overflow in multiplying two int32 values. Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier); - tensor = Multiply(tensor, scalar); + Expr scaled_tensor = Multiply(tensor, scalar); // 4) Find the rounding scalar. This depends on where the final decimal // point sits. 
As we will be right shifting the multiplied_t, we need to @@ -108,25 +149,51 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& int total_right_shift = right_shift + 31; int64_t pos_rounding_value = (1ll << (total_right_shift - 1)); + // This lambda function gathers some shared logic in "TONEAREST" and "TFLITE" + // rounding scheme, which calculates a rounder tensor according to the sign + // of values in the tensor to be rounded. + auto nearest_rounding_scalar = [&](const Expr& input_tensor, int right_shift, + DataType dtype) -> Expr { + int64_t pos_rounding_value = (1ll << (right_shift - 1)); + auto pos_rounder = MakeConstantScalar(dtype, pos_rounding_value); + auto neg_rounder = MakeConstantScalar(dtype, pos_rounding_value - 1); + auto pos_rounder_t = Full(pos_rounder, input_shape, dtype); + auto neg_rounder_t = Full(neg_rounder, input_shape, dtype); + + auto zero_t = Zeros(input_shape, dtype); + return Where(GreaterEqual(input_tensor, zero_t), pos_rounder_t, neg_rounder_t); + }; + Expr round_scalar; if (rounding == "UPWARD") { round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value); } else if (rounding == "TONEAREST") { - auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value); - auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1); - auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype); - auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); + round_scalar = nearest_rounding_scalar(scaled_tensor, total_right_shift, hp_dtype); + } else if (rounding == "TFLITE") { + auto scalar_t = Full(scalar, input_shape, hp_dtype); + bool possible_to_overflow = fixed_point_multiplier == std::numeric_limits::min(); + auto high32_t = SaturatingRoundingDoublingHigh32(tensor, scalar_t, scaled_tensor, input_shape, + possible_to_overflow); - auto zero_t = Zeros(input_shape, hp_dtype); - round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); + if (right_shift <= 0) { + scaled_tensor = high32_t; + } else { + auto zero_t = Zeros(input_shape, lp_dtype); + round_scalar = nearest_rounding_scalar(high32_t, right_shift, lp_dtype); + scaled_tensor = Add(high32_t, round_scalar); + auto rshift_expr = MakeConstantScalar(lp_dtype, right_shift); + scaled_tensor = RightShift(scaled_tensor, rshift_expr); + } + return scaled_tensor; } else { LOG(FATAL) << "Rounding mode " << rounding << " not supported."; } + // Add the rounding scalar. - tensor = Add(tensor, round_scalar); + scaled_tensor = Add(scaled_tensor, round_scalar); // 5) Simply right shift the result to get the final output. - tensor = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); + tensor = RightShift(scaled_tensor, MakeConstantScalar(hp_dtype, total_right_shift)); // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. return Cast(tensor, DataType::Int(32)); @@ -144,29 +211,36 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multipliers, // Choose high precision datatype to be int64. This is for avoiding overflow // in multiplication of two int32 values. DataType hp_dtype = DataType::Int(64); + DataType lp_dtype = DataType::Int(32); tensor = Cast(tensor, hp_dtype); // 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per // channel. 
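The TFLITE branch above chains two roundings: a gemmlowp-style saturating rounding doubling high multiply on the Q31 multiplier, then a round-to-nearest right shift. A scalar Python sketch of that sequence (the Relay code operates on whole tensors; trunc_div mimics C-style division):

def trunc_div(a, b):
    # integer division that truncates toward zero, like C / Relay Divide on ints
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q

def srdhm(a, b):
    # SaturatingRoundingDoublingHigh32: round(a * b / 2**31) with overflow saturation
    if a == -(1 << 31) and b == -(1 << 31):
        return (1 << 31) - 1
    ab = a * b
    nudge = (1 << 30) if ab >= 0 else 1 - (1 << 30)
    return trunc_div(ab + nudge, 1 << 31)

def fixed_point_multiply_tflite(x, fixed_point_multiplier, right_shift):
    high32 = srdhm(x, fixed_point_multiplier)
    if right_shift <= 0:
        return high32
    rounder = (1 << (right_shift - 1)) - (0 if high32 >= 0 else 1)
    return (high32 + rounder) >> right_shift

# Requantizing 7 with scale 1/16 (multiplier 2**30, right shift 3):
print(fixed_point_multiply_tflite(7, 1 << 30, 3))  # 1

The double rounding is visible in the updated requantize tests: for the 1/16 downscale, input 7 already maps to 1 under TFLITE, whereas UPWARD and TONEAREST both map it to 0.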
- std::vector fixed_pt_multipliers, lshifts, rshifts; - bool is_lshift_required = false; + std::vector fixed_pt_multipliers, lshifts, rshifts; + bool lshift_required = false; + bool rshift_required = false; + bool possible_to_overflow = false; for (auto multiplier : multipliers) { - int32_t fixed_pt_multiplier, shift; + int64_t fixed_pt_multiplier, shift; std::tie(fixed_pt_multiplier, shift) = GetFixedPointMultiplierShift(multiplier); - int lshift = shift > 0 ? shift : 0; - int rshift = shift > 0 ? 0 : -shift; + int64_t lshift = shift > 0 ? shift : 0; + int64_t rshift = shift > 0 ? 0 : -shift; fixed_pt_multipliers.push_back(fixed_pt_multiplier); lshifts.push_back(lshift); rshifts.push_back(rshift); - is_lshift_required = is_lshift_required | (lshift != 0); + lshift_required |= (lshift != 0); + rshift_required |= (rshift != 0); + possible_to_overflow |= (fixed_pt_multiplier == std::numeric_limits::min()); } // 2) Multiply the integer multiplier. Convert lefts shifts into expr and multiply. - if (is_lshift_required) { + if (lshift_required) { auto lshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, lshifts); auto exp_lshift_expr = ExpandBiasToMatchAxis(lshift_expr, n_dim, {channel_axis}); tensor = LeftShift(tensor, exp_lshift_expr); } + auto rshift_expr = MakeConstantTensor(lp_dtype, {n_channels}, rshifts); + auto exp_rshift_expr = ExpandBiasToMatchAxis(rshift_expr, n_dim, {channel_axis}); // 3) Perform the multiplication in higher precision. // The scalar is a fixed point value of int32 where the decimal point is @@ -177,41 +251,68 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multipliers, auto fixed_pt_multiplier_expr = MakeConstantTensor(hp_dtype, {n_channels}, fixed_pt_multipliers); auto exp_fixed_pt_multiplier_expr = ExpandBiasToMatchAxis(fixed_pt_multiplier_expr, n_dim, {channel_axis}); - tensor = Multiply(tensor, exp_fixed_pt_multiplier_expr); + auto scaled_tensor = Multiply(tensor, exp_fixed_pt_multiplier_expr); // 4) Find the rounding scalar. This depends on where the final decimal point sits. As we will be // right shifting the multiplied_t, we need to first calculate the total_rshift. Further, we can // calculate the pos and neg rounding offset. - std::vector pos_rounding_values, neg_rounding_values, total_rshifts; + std::vector pos_rounding_values, total_rshifts; for (auto rshift : rshifts) { - int total_rshift = rshift + 31; + int64_t total_rshift = rshift + 31; total_rshifts.push_back(total_rshift); pos_rounding_values.push_back((1ll << (total_rshift - 1))); - neg_rounding_values.push_back((1ll << (total_rshift - 1)) - 1); } - // Make a Relay expr from positive and negative rounding offset values. - auto pos_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, pos_rounding_values); - auto exp_pos_rounding_value_expr = - ExpandBiasToMatchAxis(pos_rounding_value_expr, n_dim, {channel_axis}); - auto neg_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, neg_rounding_values); - auto exp_neg_rounding_value_expr = - ExpandBiasToMatchAxis(neg_rounding_value_expr, n_dim, {channel_axis}); + + // This lambda function gathers some shared logic in "TONEAREST" and "TFLITE" + // rounding scheme, which calculates a rounder tensor according to the sign + // of values in the tensor to be rounded. + auto nearest_rounding_tensor = [&](const Expr& input_tensor, const std::vector& rshifts, + DataType dtype) -> Expr { + std::vector pos_rounding_values, neg_rounding_values; + for (auto rshift : rshifts) { + int64_t pos_rounding_val = rshift > 0 ? 
(1ll << (rshift - 1)) : 0; + int64_t neg_rounding_val = rshift > 0 ? ((1ll << (rshift - 1)) - 1) : 0; + pos_rounding_values.push_back(pos_rounding_val); + neg_rounding_values.push_back(neg_rounding_val); + } + // Make a Relay expr from positive and negative rounding offset values. + auto pos_rounding_value_expr = MakeConstantTensor(dtype, {n_channels}, pos_rounding_values); + auto exp_pos_rounding_value_expr = + ExpandBiasToMatchAxis(pos_rounding_value_expr, n_dim, {channel_axis}); + auto pos_rounder = MakeBroadCastTo(exp_pos_rounding_value_expr, input_shape); + auto neg_rounding_value_expr = MakeConstantTensor(dtype, {n_channels}, neg_rounding_values); + auto exp_neg_rounding_value_expr = + ExpandBiasToMatchAxis(neg_rounding_value_expr, n_dim, {channel_axis}); + auto neg_rounder = MakeBroadCastTo(exp_neg_rounding_value_expr, input_shape); + auto zero_t = Zeros(input_shape, dtype); + return Where(GreaterEqual(input_tensor, zero_t), pos_rounder, neg_rounder); + }; Expr round_scalar; if (rounding == "UPWARD") { + // Make a Relay expr from positive and negative rounding offset values. + auto pos_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, pos_rounding_values); + auto exp_pos_rounding_value_expr = + ExpandBiasToMatchAxis(pos_rounding_value_expr, n_dim, {channel_axis}); round_scalar = exp_pos_rounding_value_expr; } else if (rounding == "TONEAREST") { - // To satisfy where op shape requirements, the rounding values are broadcasted. - auto pos_rounder = MakeBroadCastTo(exp_pos_rounding_value_expr, input_shape); - auto neg_rounder = MakeBroadCastTo(exp_neg_rounding_value_expr, input_shape); - - auto zero_t = Zeros(input_shape, hp_dtype); - round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder, neg_rounder); + round_scalar = nearest_rounding_tensor(scaled_tensor, total_rshifts, hp_dtype); + } else if (rounding == "TFLITE") { + auto high32_t = SaturatingRoundingDoublingHigh32( + tensor, exp_fixed_pt_multiplier_expr, scaled_tensor, input_shape, possible_to_overflow); + if (!rshift_required) { + return high32_t; + } else { + auto zero_t = Zeros(input_shape, lp_dtype); + round_scalar = nearest_rounding_tensor(high32_t, rshifts, lp_dtype); + scaled_tensor = Add(high32_t, round_scalar); + return RightShift(scaled_tensor, exp_rshift_expr); + } } else { LOG(FATAL) << "Rounding mode " << rounding << " not supported."; } // Add the rounding scalar. - tensor = Add(tensor, round_scalar); + tensor = Add(scaled_tensor, round_scalar); // 5) Simply right shift the result to get the final output. auto total_rshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, total_rshifts); diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index 736b7361a300..859463177b54 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -98,8 +98,8 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) { * \param tensor The quantized input tensor of dtype int64. * \param multiplier The scalar multiplier. * \param input_shape Shape of the input tensor. - * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value - is midway between" "two representable values. + * \param rounding "UPWARD", "TONEAREST" or "TFLITE". The rounding direction + * when the value is midway between" "two representable values. * \return The sequence of Relay ops for fixed point multiplication. * \note Original compuation is scale_fp32 * quantized_tensor. 
To convert into @@ -125,8 +125,8 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& * \param input_shape Shape of the input tensor. * \param channel_axis The channel_axis along which the input tensor is quantized. Default value is -1 which corresponds to the last channel_axis. - * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the value - is midway between" "two representable values. + * \param rounding "UPWARD", "TONEAREST" or "TFLITE". The rounding direction + * when the value is midway between" "two representable values. * \return The sequence of Relay ops for fixed point multiplication. * \note Original compuation is scale_fp32 * quantized_tensor. To convert into @@ -193,6 +193,27 @@ static inline std::vector GetFloatVectorFromConstant(const Expr& expr) { return vals; } +static inline std::vector GetDoubleVectorFromConstant(const Expr& expr) { + const auto* n = expr.as(); + std::vector vals; + CHECK(n) << "Expr must be a constant expr - " << AsText(expr, false); + int64_t num_elems = 1; + auto shape = n->data.Shape(); + for (size_t i = 0; i < shape.size(); i++) { + num_elems *= shape[i]; + } + for (int64_t i = 0; i < num_elems; i++) { + vals.push_back(static_cast(n->data->data)[i]); + } + return vals; +} + +static inline DataType GetDataTypeFromConstant(const Expr& expr) { + const auto* n = expr.as(); + CHECK(n) << "Expr must be a constant expr - " << AsText(expr, false); + return n->tensor_type()->dtype; +} + } // namespace qnn } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 0a51404911e0..a7e98586116a 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -503,6 +503,21 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { return Call(op, {lhs, rhs}, Attrs(), {}); } +static inline Expr Greater(Expr lhs, Expr rhs) { + static const Op& op = Op::Get("greater"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + +static inline Expr Equal(Expr lhs, Expr rhs) { + static const Op& op = Op::Get("equal"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + +static inline Expr LogicalAnd(Expr lhs, Expr rhs) { + static const Op& op = Op::Get("logical_and"); + return Call(op, {lhs, rhs}, Attrs(), {}); +} + static inline Expr Full(Expr fill_value, Array shape, DataType dtype) { auto attrs = make_object(); attrs->shape = std::move(shape); diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 9963479fd8f7..cf3089fdf530 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -84,7 +84,7 @@ def get_real_image_object_detection(im_height, im_width): return data def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', - out_names=None): + out_names=None, opt_level=3): """ Generic function to compile on relay and execute on tvm """ # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 try: @@ -109,7 +109,7 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target shape_dict=shape_dict, dtype_dict=dtype_dict) - with relay.build_config(opt_level=3): + with relay.build_config(opt_level=opt_level): graph, lib, params = relay.build(mod, target, params=params) ctx = tvm.context(target, 0) @@ -1978,18 +1978,15 @@ def test_forward_qnn_inception_v1_net(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() - # Test image. 
Checking the labels because the requantize implementation is different between - # TFLite and Relay. This cause final output numbers to mismatch. So, testing accuracy via - # labels. Also, giving a real image, instead of random inputs. - data = get_real_image(224, 224) + np.random.seed(0) + data = np.random.randint(256, size=(1, 224, 224, 3)).astype('uint8') tflite_output = run_tflite_graph(tflite_model_buf, data) tflite_predictions = np.squeeze(tflite_output) - tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') tvm_predictions = np.squeeze(tvm_output) - tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] - tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + tvm.testing.assert_allclose(tvm_predictions, tflite_predictions, + rtol=0, atol=0) def test_forward_qnn_mobilenet_v1_net(): """Test the Quantized TFLite Mobilenet V1 model.""" @@ -2000,18 +1997,15 @@ def test_forward_qnn_mobilenet_v1_net(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() - # Test image. Checking the labels because the requantize implementation is different between - # TFLite and Relay. This cause final output numbers to mismatch. So, testing accuracy via - # labels. Also, giving a real image, instead of random inputs. - data = get_real_image(224, 224) + np.random.seed(0) + data = np.random.randint(256, size=(1, 224, 224, 3)).astype('uint8') tflite_output = run_tflite_graph(tflite_model_buf, data) - tflite_predictions = np.squeeze(tflite_output) - tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + tflite_predictions = np.squeeze(tflite_output).astype('int32') tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') - tvm_predictions = np.squeeze(tvm_output) - tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] - tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + tvm_predictions = np.squeeze(tvm_output).astype('int32') + tvm.testing.assert_allclose(tvm_predictions, tflite_predictions, + rtol=0, atol=0) def test_forward_qnn_mobilenet_v2_net(): """Test the Quantized TFLite Mobilenet V2 model.""" @@ -2022,18 +2016,16 @@ def test_forward_qnn_mobilenet_v2_net(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() - # Test image. Checking the labels because the requantize implementation is different between - # TFLite and Relay. This cause final output numbers to mismatch. So, testing accuracy via - # labels. Also, giving a real image, instead of random inputs. 
- data = get_real_image(224, 224) + np.random.seed(43) + # TODO: np.random.seed(43) setting py3 + data = np.random.randint(256, size=(1, 224, 224, 3)).astype('uint8') tflite_output = run_tflite_graph(tflite_model_buf, data) - tflite_predictions = np.squeeze(tflite_output) - tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] - tvm_output = run_tvm_graph(tflite_model_buf, data, 'input') - tvm_predictions = np.squeeze(tvm_output) - tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1] - tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) + tflite_predictions = np.squeeze(tflite_output).astype('int32') + tvm_output = run_tvm_graph(tflite_model_buf, data, 'input', opt_level=2) + tvm_predictions = np.squeeze(tvm_output).astype('int32') + tvm.testing.assert_allclose(tvm_predictions, tflite_predictions, + rtol=0, atol=0) ####################################################################### # Mobilenet V3 Quantized diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index cf9d2d43eb12..5ffcc08f7b94 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -818,8 +818,18 @@ def test_pool2d(): _test_pool2d(relay.nn.max_pool2d, np.max, pool_size=2, strides=2, padding=0) _test_pool2d(relay.nn.avg_pool2d, np.mean) _test_pool2d(relay.nn.avg_pool2d, np.mean, pool_size=2, strides=2, padding=0) - _test_pool2d_int(relay.nn.avg_pool2d, np.mean, 'int32') - _test_pool2d_int(relay.nn.avg_pool2d, np.mean, 'uint16') + + def mean_integer(int_array, axis, keepdims=False): + sum_array = np.sum(int_array, axis=axis, keepdims=keepdims) + input_shape = np.shape(int_array) + kernel_size = 1 + for a in axis: + kernel_size = kernel_size * input_shape[a] + sum_array += kernel_size // 2 + return sum_array // kernel_size + + _test_pool2d_int(relay.nn.avg_pool2d, mean_integer, 'int32') + _test_pool2d_int(relay.nn.avg_pool2d, mean_integer, 'uint16') _test_global_pool2d(relay.nn.global_max_pool2d, np.max) _test_global_pool2d(relay.nn.global_avg_pool2d, np.mean) diff --git a/tests/python/relay/test_op_qnn_concatenate.py b/tests/python/relay/test_op_qnn_concatenate.py index fb60e9805206..519200799cda 100644 --- a/tests/python/relay/test_op_qnn_concatenate.py +++ b/tests/python/relay/test_op_qnn_concatenate.py @@ -27,8 +27,8 @@ def test_same_io_qnn_params(): axis = 0 x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype) y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype) - x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32') - y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32') + x_scale = relay.const(2.933666e-8, 'float32') + y_scale = relay.const(2.933666e-8, 'float32') zero = relay.const(0, 'int32') x = relay.var("x", shape=(1, 64), dtype=data_dtype) @@ -57,8 +57,8 @@ def test_different_io_qnn_params(): x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype) y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype) - x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32') - y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32') + x_scale = relay.const(2.933666e-8, 'float32') + y_scale = relay.const(2.933666e-8, 'float32') x_zero_point = relay.const(3, 'int32') y_zero_point = relay.const(4, 'int32') @@ -88,8 +88,8 @@ def test_few_same_io_qnn_params(): x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype) y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype) - x_scale = relay.const((62 + 64) / 
(np.power(2, 32) - 1.0), 'float32') - y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32') + x_scale = relay.const(2.933666e-8, 'float32') + y_scale = relay.const(2.933666e-8, 'float32') x_zero_point = relay.const(0, 'int32') y_zero_point = relay.const(1, 'int32') @@ -119,8 +119,8 @@ def test_same_i_qnn_params(): x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype) y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype) - x_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32') - y_scale = relay.const((62 + 64) / (np.power(2, 32) - 1.0), 'float32') + x_scale = relay.const(2.933666e-8, 'float32') + y_scale = relay.const(2.933666e-8, 'float32') x_zero_point = relay.const(0, 'int32') y_zero_point = relay.const(0, 'int32') diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index 81233972cb28..4ffbe3979d14 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -21,7 +21,7 @@ from tvm import relay from tvm.contrib import graph_runtime -roundings = ["UPWARD", "TONEAREST"] +roundings = ["UPWARD", "TONEAREST", "TFLITE"] def verify(mod, goldens): with relay.build_config(opt_level=3): @@ -90,7 +90,10 @@ def test_downscale(): # Try positive values # 8 corresponds to 0.5, resulting in 1 golden_data = np.arange(0, 32, 1).astype('int32') - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + if rounding == "TFLITE": + golden_output = np.repeat([0, 1, 2], [7, 16, 9]) + else: + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) verify(mod, (golden_data, golden_output)) # Try negative values @@ -113,8 +116,12 @@ def test_downscale(): # Try positive values # 2I corresponds to 0.5, resulting in 1 golden_data = np.arange(0, 32, 1).astype('int32') - golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], - [2, 4, 4, 4, 4, 4, 4, 4, 2]) + if rounding == "TFLITE": + golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], + [1, 4, 4, 4, 4, 4, 4, 4, 3]) + else: + golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], + [2, 4, 4, 4, 4, 4, 4, 4, 2]) verify(mod, (golden_data, golden_output)) # Try negative values @@ -139,7 +146,10 @@ def test_downscale(): # Try positive values # 8 corresponds to 0.5, resulting in 1 golden_data = np.arange(0, 32, 1).astype('int32') - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + if rounding == "TFLITE": + golden_output = np.repeat([0, 1, 2], [7, 16, 9]) + else: + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) verify(mod, (golden_data, golden_output)) # Try uint8 in_dtyope and uint8 out_dtype @@ -153,7 +163,10 @@ def test_downscale(): # Try positive values # 8 corresponds to 0.5, resulting in 1 golden_data = np.arange(0, 32, 1).astype('int32') - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + if rounding == "TFLITE": + golden_output = np.repeat([0, 1, 2], [7, 16, 9]) + else: + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) verify(mod, (golden_data, golden_output)) def test_upscale(): @@ -214,7 +227,10 @@ def test_zero_point(): # Try positive values # 8 corresponds to 0.5, resulting in 1 golden_data = np.arange(0, 32, 1).astype('int32') - golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + if rounding == "TFLITE": + golden_output = np.repeat([0, 1, 2], [7, 16, 9]) + else: + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) golden_output = np.add(1, golden_output) verify(mod, (golden_data, golden_output)) @@ -240,7 +256,10 @@ def test_zero_point(): # Try positive values golden_data = np.arange(32, 64, 1).astype('int32') - 
@@ -240,7 +256,10 @@ def test_zero_point():
 
         # Try positive values
         golden_data = np.arange(32, 64, 1).astype('int32')
-        golden_output = np.repeat([2, 3, 4], [8, 16, 8])
+        if rounding == "TFLITE":
+            golden_output = np.repeat([2, 3, 4], [7, 16, 9])
+        else:
+            golden_output = np.repeat([2, 3, 4], [8, 16, 8])
         golden_output = np.subtract(golden_output, 1)
         verify(mod, (golden_data, golden_output))
 
@@ -284,17 +303,29 @@ def test_per_channel_same_scale():
 
 def test_per_channel_different_scale():
     # Have same scales, everything within range
-    golden_data = np.arange(-5, 5, 1).astype('int32').reshape((5,2))
-    golden_output = np.array([-5, -2, -3, -1, -1, 0, 1, 1, 3, 2]).reshape((5, 2))
-
+    golden_data = np.arange(-32, 32, 1).astype('int32').reshape((32,2))
+
     for rounding in roundings:
-        mod = get_mod(data_shape=(5, 2),
+        mod = get_mod(data_shape=(32, 2),
                       data_dtype='int32',
                       out_dtype="int8",
-                      input_scale=[0.5, 0.25],
-                      output_scale=0.5,
+                      input_scale=[1, 32],
+                      output_scale=16,
                       axis=1,
                       rounding=rounding)
+
+        if rounding == "UPWARD":
+            golden_output = np.array(
+                [-2, -62, -2, -58, -2, -54, -2, -50, -1, -46, -1, -42, -1, -38,
+                 -1, -34, -1, -30, -1, -26, -1, -22, -1, -18, 0, -14, 0, -10,
+                 0, -6, 0, -2, 0, 2, 0, 6, 0, 10, 0, 14, 1, 18, 1, 22, 1, 26,
+                 1, 30, 1, 34, 1, 38, 1, 42, 1, 46, 2, 50, 2, 54, 2, 58, 2, 62]).reshape(32, 2)
+        else:
+            golden_output = np.array(
+                [-2, -62, -2, -58, -2, -54, -2, -50, -2, -46, -1, -42, -1, -38,
+                 -1, -34, -1, -30, -1, -26, -1, -22, -1, -18, -1, -14, 0, -10,
+                 0, -6, 0, -2, 0, 2, 0, 6, 0, 10, 0, 14, 1, 18, 1, 22, 1, 26,
+                 1, 30, 1, 34, 1, 38, 1, 42, 1, 46, 2, 50, 2, 54, 2, 58, 2, 62]).reshape(32, 2)
         verify(mod, (golden_data, golden_output))
 
     # Change axis
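In the per-channel case only channel 0 exercises the rounding mode (its requantize divides by 16), while channel 1 is an exact multiply by 2, so both goldens share the second column. A sketch of regenerating the UPWARD golden, assuming the input_scale=[1, 32], output_scale=16 setup above:

import numpy as np

golden_data = np.arange(-32, 32, 1).astype('int64').reshape((32, 2))

# Channel 0: input_scale=1, output_scale=16 -> divide by 16, UPWARD rounding
ch0 = np.floor(golden_data[:, 0] / 16.0 + 0.5).astype('int32')
# Channel 1: input_scale=32, output_scale=16 -> exact multiply by 2
ch1 = (golden_data[:, 1] * 2).astype('int32')

golden_output = np.stack([ch0, ch1], axis=1)
print(golden_output[:4])  # [[-2 -62] [-2 -58] [-2 -54] [-2 -50]]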
diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h
index ffc4f9856a65..0262b1e3d658 100644
--- a/topi/include/topi/nn/pooling.h
+++ b/topi/include/topi/nn/pooling.h
@@ -145,26 +145,54 @@ inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
         "tensor", "pool_sum");
 
     // TVM compute for dividing the reduced window sum by kernel size.
-    return tvm::te::compute(
-        out_shape,
-        [&](const Array<Var>& output) {
-          Array<PrimExpr> indices;
-          for (const Var& var : output) indices.push_back(var);
-          if (count_include_pad) {
-            return div(pool_sum(indices), (kernel_height * kernel_width));
-          } else {
-            PrimExpr h_start = output[height_axis] * stride_height - pad_top;
-            PrimExpr w_start = output[width_axis] * stride_width - pad_left;
-            PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height);
-            PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width);
-            h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0));
-            w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0));
-            PrimExpr divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start),
-                                                        make_const(DataType::DataType::Int(32), 1));
-            return div(pool_sum(indices), divide_factor);
-          }
-        },
-        "tensor", kElementWise);
+    if (x->dtype.code() == DataType::kInt || x->dtype.code() == DataType::kUInt) {
+      return tvm::te::compute(
+          out_shape,
+          [&](const Array<Var>& output) {
+            Array<PrimExpr> indices;
+            for (const Var& var : output) indices.push_back(var);
+            if (count_include_pad) {
+              PrimExpr kernel_size = kernel_height * kernel_width;
+              PrimExpr up_rounder = floordiv(kernel_size, 2);
+              return floordiv(pool_sum(indices) + up_rounder, kernel_size);
+            } else {
+              PrimExpr h_start = output[height_axis] * stride_height - pad_top;
+              PrimExpr w_start = output[width_axis] * stride_width - pad_left;
+              PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height);
+              PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width);
+              h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0));
+              w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0));
+              PrimExpr divide_factor =
+                  tir::MaxNode::make((h_end - h_start) * (w_end - w_start),
+                                     make_const(DataType::DataType::Int(32), 1));
+              PrimExpr up_rounder = floordiv(divide_factor, 2);
+              return floordiv(pool_sum(indices) + up_rounder, divide_factor);
+            }
+          },
+          "tensor", kElementWise);
+    } else {
+      return tvm::te::compute(
+          out_shape,
+          [&](const Array<Var>& output) {
+            Array<PrimExpr> indices;
+            for (const Var& var : output) indices.push_back(var);
+            if (count_include_pad) {
+              return div(pool_sum(indices), (kernel_height * kernel_width));
+            } else {
+              PrimExpr h_start = output[height_axis] * stride_height - pad_top;
+              PrimExpr w_start = output[width_axis] * stride_width - pad_left;
+              PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height);
+              PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width);
+              h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0));
+              w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0));
+              PrimExpr divide_factor =
+                  tir::MaxNode::make((h_end - h_start) * (w_end - w_start),
+                                     make_const(DataType::DataType::Int(32), 1));
+              return div(pool_sum(indices), divide_factor);
+            }
+          },
+          "tensor", kElementWise);
+    }
   } else {
     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
     return x;
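The new integer branch replaces plain truncating division of the window sum with a rounded floordiv, floordiv(sum + kernel_size / 2, kernel_size), so quantized avg pooling matches the rounded reference used in the tests. A small Python sketch of the count_include_pad path (the window values are illustrative):

def avg_pool_value(window_sum, kernel_height, kernel_width, integer_input):
    kernel_size = kernel_height * kernel_width
    if integer_input:
        # New int/uint path: floordiv(sum + kernel_size // 2, kernel_size)
        up_rounder = kernel_size // 2
        return (window_sum + up_rounder) // kernel_size
    # Float path is unchanged: plain division of the window sum
    return window_sum / kernel_size

print(avg_pool_value(7, 2, 2, integer_input=True))   # 2  (1.75 rounded to nearest)
print(avg_pool_value(7, 2, 2, integer_input=False))  # 1.75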
@@ -526,21 +554,40 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
         }, "tensor", "adaptive_pool_sum");
 
-    return tvm::te::compute(
-        out_shape,
-        [&](const Array<Var>& output) {
-          Array<PrimExpr> indices;
-          Array<IterVar> reduce_axes;
-          std::tie(indices, reduce_axes) = get_iter_vars(output, false);
+    if (x->dtype.code() == DataType::kInt || x->dtype.code() == DataType::kUInt) {
+      return tvm::te::compute(
+          out_shape,
+          [&](const Array<Var>& output) {
+            Array<PrimExpr> indices;
+            Array<IterVar> reduce_axes;
+            std::tie(indices, reduce_axes) = get_iter_vars(output, false);
+
+            PrimExpr divide_factor = tvm::cast(x->dtype, 1);
+            for (size_t i = 0; i < n_dim; ++i) {
+              divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
+            }
-          PrimExpr divide_factor = tvm::cast(x->dtype, 1);
-          for (size_t i = 0; i < n_dim; ++i) {
-            divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
-          }
+            PrimExpr up_rounder = div(divide_factor, 2);
+            return div(add(pool_sum(indices), up_rounder), divide_factor);
+          },
+          "tensor", kElementWise);
+    } else {
+      return tvm::te::compute(
+          out_shape,
+          [&](const Array<Var>& output) {
+            Array<PrimExpr> indices;
+            Array<IterVar> reduce_axes;
+            std::tie(indices, reduce_axes) = get_iter_vars(output, false);
+
+            PrimExpr divide_factor = tvm::cast(x->dtype, 1);
+            for (size_t i = 0; i < n_dim; ++i) {
+              divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
+            }
-          return div(pool_sum(indices), divide_factor);
-        },
-        "tensor", kElementWise);
+            return div(pool_sum(indices), divide_factor);
+          },
+          "tensor", kElementWise);
+    }
   } else {
     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
     return x;
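Adaptive average pooling gets the same treatment: divide_factor is the product of the reduction extents, and the integer branch adds divide_factor / 2 before dividing (with non-negative sums this matches the truncating div used above). A sketch over a 4x4 int32 map with a 2x2 output (shapes are illustrative):

import numpy as np

x = np.arange(16, dtype='int64').reshape(4, 4)  # one channel, batch/channel dims omitted

out = np.empty((2, 2), dtype='int64')
for oh in range(2):
    for ow in range(2):
        window = x[oh * 2:(oh + 1) * 2, ow * 2:(ow + 1) * 2]
        divide_factor = window.size          # product of the reduce extents
        up_rounder = divide_factor // 2      # matches div(divide_factor, 2)
        out[oh, ow] = (window.sum() + up_rounder) // divide_factor

print(out)  # [[ 3  5] [11 13]]  (means 2.5, 4.5, 10.5, 12.5 rounded up)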
@@ -725,35 +772,69 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
         "tensor", "pool_sum");
 
     // TVM compute for dividing the reduced window sum by kernel size.
-    return tvm::te::compute(
-        out_shape,
-        [&](const Array<Var>& output) {
-          Array<PrimExpr> indices;
-          for (const Var& var : output) indices.push_back(var);
-          if (count_include_pad) {
-            auto kernel_size = make_const(DataType::Int(32), 1);
-            for (int i = 0; i < k_size; i++) {
-              kernel_size *= kernel[i];
+    if (x->dtype.code() == DataType::kInt || x->dtype.code() == DataType::kUInt) {
+      return tvm::te::compute(
+          out_shape,
+          [&](const Array<Var>& output) {
+            Array<PrimExpr> indices;
+            for (const Var& var : output) indices.push_back(var);
+            if (count_include_pad) {
+              auto kernel_size = make_const(DataType::Int(32), 1);
+              for (int i = 0; i < k_size; i++) {
+                kernel_size *= kernel[i];
+              }
+              PrimExpr up_rounder = div(kernel_size, 2);
+              return div(add(pool_sum(indices), up_rounder), kernel_size);
+            } else {
+              std::vector<PrimExpr> start(k_size);
+              std::vector<PrimExpr> end(k_size);
+              auto kernel_size = make_const(DataType::Int(32), 1);
+              for (int i = 0; i < k_size; i++) {
+                int ii = axis[i];
+                start[i] = output[ii] * stride[i] - pad_head[i];
+                end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]);
+                start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0));
+                kernel_size *= (end[i] - start[i]);
+              }
+
+              PrimExpr divide_factor =
+                  tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
+              PrimExpr up_rounder = div(divide_factor, 2);
+              return div(add(pool_sum(indices), up_rounder), divide_factor);
             }
-            return div(pool_sum(indices), kernel_size);
-          } else {
-            std::vector<PrimExpr> start(k_size);
-            std::vector<PrimExpr> end(k_size);
-            auto kernel_size = make_const(DataType::Int(32), 1);
-            for (int i = 0; i < k_size; i++) {
-              int ii = axis[i];
-              start[i] = output[ii] * stride[i] - pad_head[i];
-              end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]);
-              start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0));
-              kernel_size *= (end[i] - start[i]);
+          },
+          "tensor", kElementWise);
+    } else {
+      return tvm::te::compute(
+          out_shape,
+          [&](const Array<Var>& output) {
+            Array<PrimExpr> indices;
+            for (const Var& var : output) indices.push_back(var);
+            if (count_include_pad) {
+              auto kernel_size = make_const(DataType::Int(32), 1);
+              for (int i = 0; i < k_size; i++) {
+                kernel_size *= kernel[i];
+              }
+              return div(pool_sum(indices), kernel_size);
+            } else {
+              std::vector<PrimExpr> start(k_size);
+              std::vector<PrimExpr> end(k_size);
+              auto kernel_size = make_const(DataType::Int(32), 1);
+              for (int i = 0; i < k_size; i++) {
+                int ii = axis[i];
+                start[i] = output[ii] * stride[i] - pad_head[i];
+                end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]);
+                start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0));
+                kernel_size *= (end[i] - start[i]);
+              }
+
+              PrimExpr divide_factor =
+                  tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
+              return div(pool_sum(indices), divide_factor);
            }
-
-          PrimExpr divide_factor =
-              tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1));
-          return div(pool_sum(indices), divide_factor);
-        }
-      },
-      "tensor", kElementWise);
+          },
+          "tensor", kElementWise);
+    }
   } else {
     LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
     return x;
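When count_include_pad is false, the divisor is the clipped window area computed per output element, and the integer branch rounds against that per-output divide_factor. A 1-D Python sketch of how one output value is formed (the kernel/stride/pad values are illustrative):

import numpy as np

def pooled_int_value(x, out_idx, kernel, stride, pad_head):
    # x: 1-D int array; mirrors the per-axis start/end clipping in pool_impl_nd.
    start = out_idx * stride - pad_head
    end = min(start + kernel, x.shape[0])
    start = max(start, 0)
    divide_factor = max(end - start, 1)
    window_sum = int(np.sum(x[start:end]))
    up_rounder = divide_factor // 2
    return (window_sum + up_rounder) // divide_factor

x = np.array([4, 5, 6, 7], dtype='int64')
# First output with kernel=3, stride=2, pad_head=1: only x[0:2] lies inside the
# input, so divide_factor = 2 and the value is (9 + 1) // 2 = 5.
print(pooled_int_value(x, out_idx=0, kernel=3, stride=2, pad_head=1))  # 5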