From ea9eb060028a69f2801cdd22e7ad90b169d1eafd Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 23 Aug 2021 14:56:53 -0600 Subject: [PATCH 01/15] WIP support per-channel quantization --- python/tvm/relay/frontend/onnx.py | 2 +- .../transform/fake_quantization_to_integer.py | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9f89c5ee9476..2e8f6ec5ad7d 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -479,7 +479,7 @@ def _impl_v1(cls, inputs, attr, params): attr["dilations"] = [1] + list(attr["dilations"]) if "pads" in attr: attr["pads"] = [0, attr["pads"][0], 0, attr["pads"][1]] - + attr["channels"] = kernel_shapes[0][0] out = AttrCvt( op_name=dimension_picker("conv"), transforms={ diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index cf55c67c8083..50628358e045 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -52,6 +52,7 @@ def quantize(expr, type_map): expr.args[1], expr.args[2], out_dtype=expr.attrs.out_dtype, + axis=expr.attrs.axis, ) return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)] @@ -116,7 +117,11 @@ def conv2d(expr, type_map): x_t = type_map[x] w_t = type_map[weight] conv_scale = fold_constant(x_t.scale * w_t.scale) - conv_zp = relay.const(0) + shape = list(relay.transform.InferType()(tvm.IRModule.from_expr(conv_scale))["main"].body.checked_type.shape) + if len(shape) == 0: + conv_zp = relay.const(0) + else: + conv_zp = relay.const([0] * shape[0].value) out = relay.qnn.op.conv2d( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) @@ -198,15 +203,22 @@ def clip(expr, type_map): amax = expr.attrs.a_max scale = fold_constant(t.scale) z_p = fold_constant(t.zero_point) - if isinstance(scale, relay.expr.Constant) and isinstance(z_p, relay.expr.Constant): + if ( + isinstance(scale, relay.expr.Constant) + and scale.data.numpy().size == 1 + and isinstance(z_p, relay.expr.Constant) + and z_p.data.numpy().size == 1 + ): scale = scale.data.numpy().item() z_p = z_p.data.numpy().item() new_min = int(amin / scale + z_p) new_max = int(amax / scale + z_p) out = relay.op.clip(arg, new_min, new_max) else: - amin = relay.op.round(relay.op.const(amin) / scale + z_p) - amax = relay.op.round(relay.op.const(amax) / scale + z_p) + amin = relay.op.cast(relay.op.round(relay.op.const(amin) / scale), t.dtype) + z_p + amax = relay.op.cast(relay.op.round(relay.op.const(amax) / scale), t.dtype) + z_p + amin = relay.op.reshape(amin, [1, -1, 1, 1]) + amax = relay.op.reshape(amax, [1, -1, 1, 1]) out = relay.op.minimum(relay.op.maximum(arg, amin), amax) return [out, t] From f33f94fb75d2d68db64f5f3251194807ebaf7a62 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 23 Aug 2021 15:34:57 -0600 Subject: [PATCH 02/15] more WIP --- .../transform/fake_quantization_to_integer.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 50628358e045..bd504cfb96b3 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -117,11 +117,16 @@ def conv2d(expr, type_map): x_t = type_map[x] w_t = type_map[weight] conv_scale = 
fold_constant(x_t.scale * w_t.scale) - shape = list(relay.transform.InferType()(tvm.IRModule.from_expr(conv_scale))["main"].body.checked_type.shape) + oc_axis = attrs["kernel_layout"].find("O") + shape = list( + relay.transform.InferType()(tvm.IRModule.from_expr(conv_scale))[ + "main" + ].body.checked_type.shape + ) if len(shape) == 0: conv_zp = relay.const(0) else: - conv_zp = relay.const([0] * shape[0].value) + conv_zp = relay.const([0] * shape[oc_axis].value) out = relay.qnn.op.conv2d( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) @@ -203,6 +208,10 @@ def clip(expr, type_map): amax = expr.attrs.a_max scale = fold_constant(t.scale) z_p = fold_constant(t.zero_point) + if not isinstance(amin, relay.expr.Constant): + amin = relay.op.const(amin) + if not isinstance(amax, relay.expr.Constant): + amax = relay.op.const(amax) if ( isinstance(scale, relay.expr.Constant) and scale.data.numpy().size == 1 @@ -215,8 +224,8 @@ def clip(expr, type_map): new_max = int(amax / scale + z_p) out = relay.op.clip(arg, new_min, new_max) else: - amin = relay.op.cast(relay.op.round(relay.op.const(amin) / scale), t.dtype) + z_p - amax = relay.op.cast(relay.op.round(relay.op.const(amax) / scale), t.dtype) + z_p + amin = relay.op.cast(relay.op.round(amin / scale), t.dtype) + z_p + amax = relay.op.cast(relay.op.round(amax / scale), t.dtype) + z_p amin = relay.op.reshape(amin, [1, -1, 1, 1]) amax = relay.op.reshape(amax, [1, -1, 1, 1]) out = relay.op.minimum(relay.op.maximum(arg, amin), amax) From ed047ec728382efb67b6a840977f909f19eb2442 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 25 Aug 2021 12:25:35 -0600 Subject: [PATCH 03/15] More WIP --- include/tvm/ir/affine_type.h | 9 ++- python/tvm/ir/affine_type.py | 4 +- .../transform/fake_quantization_to_integer.py | 58 +++++++++++-------- src/ir/affine_type.cc | 9 +-- .../fake_quantization_to_integer.cc | 9 ++- .../test_pass_fake_quantization_to_integer.py | 34 ++++++++++- 6 files changed, 88 insertions(+), 35 deletions(-) diff --git a/include/tvm/ir/affine_type.h b/include/tvm/ir/affine_type.h index afbe1f343bb8..34d63faa3112 100644 --- a/include/tvm/ir/affine_type.h +++ b/include/tvm/ir/affine_type.h @@ -71,17 +71,21 @@ class TensorAffineTypeNode : public AffineTypeNode { RelayExpr zero_point; /*! \brief The data type of this type */ DataType dtype; + /*! 
\brief The data type of this type */ + int axis; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("scale", &scale); v->Visit("zero_point", &zero_point); v->Visit("dtype", &dtype); + v->Visit("axis", &axis); } bool SEqualReduce(const TensorAffineTypeNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal(scale, other->scale) && equal(zero_point, other->zero_point) && - equal(dtype, other->dtype); + equal(dtype, other->dtype) && equal(axis, other->axis); + } void SHashReduce(SHashReducer hash_reduce) const { @@ -89,6 +93,7 @@ class TensorAffineTypeNode : public AffineTypeNode { hash_reduce(scale); hash_reduce(zero_point); hash_reduce(dtype); + hash_reduce(axis); } static constexpr const char* _type_key = "TensorAffineType"; @@ -101,7 +106,7 @@ class TensorAffineTypeNode : public AffineTypeNode { */ class TensorAffineType : public AffineType { public: - TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype); + TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis); TVM_DEFINE_OBJECT_REF_METHODS(TensorAffineType, AffineType, TensorAffineTypeNode); }; diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py index a1ce08017b1b..0d6ba7122e08 100644 --- a/python/tvm/ir/affine_type.py +++ b/python/tvm/ir/affine_type.py @@ -50,8 +50,8 @@ class TensorAffineType(AffineType): The content data type. """ - def __init__(self, scale, zero_point, dtype): - self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype) + def __init__(self, scale, zero_point, dtype, axis=-1): + self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype, axis) @tvm._ffi.register_object("TupleAffineType") diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index bd504cfb96b3..7115575397a8 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -18,12 +18,18 @@ import tvm from tvm import relay from tvm.ir import TensorAffineType, TupleAffineType +from tvm.tir import bijective_layout from ..op import register_fake_quantization_to_integer def fold_constant(expr): return relay.transform.FoldConstantExpr(expr, tvm.IRModule()) +def get_zeros(scale): + return fold_constant(relay.op.cast(relay.op.zeros_like(scale), "int32")) + +def infer_shape(expr): + return relay.transform.InferType()(tvm.IRModule.from_expr(expr))["main"].body.checked_type.shape @register_fake_quantization_to_integer("qnn.dequantize") def dequantize(expr, type_map): @@ -52,9 +58,9 @@ def quantize(expr, type_map): expr.args[1], expr.args[2], out_dtype=expr.attrs.out_dtype, - axis=expr.attrs.axis, + axis=t.axis, ) - return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)] + return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis)] def register_unary_identity(op_name): @@ -103,6 +109,7 @@ def bias_add(expr, type_map): in_scale, in_zero_point, out_dtype=x_t.dtype, + axis=0, ) out = relay.op.nn.bias_add(x, b, **expr.attrs) return [out, x_t] @@ -117,20 +124,14 @@ def conv2d(expr, type_map): x_t = type_map[x] w_t = type_map[weight] conv_scale = fold_constant(x_t.scale * w_t.scale) - oc_axis = attrs["kernel_layout"].find("O") - shape = list( - relay.transform.InferType()(tvm.IRModule.from_expr(conv_scale))[ - "main" - ].body.checked_type.shape - ) - if len(shape) == 0: - conv_zp = relay.const(0) - 
else: - conv_zp = relay.const([0] * shape[oc_axis].value) + conv_zp = get_zeros(conv_scale) out = relay.qnn.op.conv2d( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype)] + scale_shape = infer_shape(conv_scale) + out_layout = attrs["out_layout"] if attrs["out_layout"] != "" else attrs["data_layout"] + out_axis = tvm.tir.bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1] + return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis.value)] @register_fake_quantization_to_integer("nn.dense") @@ -142,11 +143,11 @@ def dense(expr, type_map): x_t = type_map[x] w_t = type_map[weight] dense_scale = fold_constant(x_t.scale * w_t.scale) - dense_zp = relay.const(0) + dense_zp = get_zeros(dense_scale) out = relay.qnn.op.dense( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)] + return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype, x_t.axis)] @register_fake_quantization_to_integer("nn.batch_matmul") @@ -158,7 +159,7 @@ def batch_matmul(expr, type_map): matmul_scale = fold_constant(x_t.scale * y_t.scale) matmul_zp = relay.const(0) out = relay.qnn.op.batch_matmul(x, y, x_t.zero_point, y_t.zero_point, x_t.scale, y_t.scale) - return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype)] + return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype, x_t.axis)] @register_fake_quantization_to_integer("concatenate") @@ -208,10 +209,6 @@ def clip(expr, type_map): amax = expr.attrs.a_max scale = fold_constant(t.scale) z_p = fold_constant(t.zero_point) - if not isinstance(amin, relay.expr.Constant): - amin = relay.op.const(amin) - if not isinstance(amax, relay.expr.Constant): - amax = relay.op.const(amax) if ( isinstance(scale, relay.expr.Constant) and scale.data.numpy().size == 1 @@ -224,11 +221,21 @@ def clip(expr, type_map): new_max = int(amax / scale + z_p) out = relay.op.clip(arg, new_min, new_max) else: - amin = relay.op.cast(relay.op.round(amin / scale), t.dtype) + z_p - amax = relay.op.cast(relay.op.round(amax / scale), t.dtype) + z_p - amin = relay.op.reshape(amin, [1, -1, 1, 1]) - amax = relay.op.reshape(amax, [1, -1, 1, 1]) + if not isinstance(amin, relay.expr.Constant): + amin = relay.op.const(amin) + if not isinstance(amax, relay.expr.Constant): + amax = relay.op.const(amax) + + scale_shape =infer_shape(scale) + if len(scale_shape)>0 and scale_shape[0] > 1: + b_shape = [1] * len(infer_shape(arg)) + b_shape[t.axis] = -1 + amin = relay.op.reshape(relay.op.broadcast_to(amin, scale_shape), b_shape) + amax = relay.op.reshape(relay.op.broadcast_to(amax, scale_shape), b_shape) + amin = relay.qnn.op.quantize(amin, scale, z_p, t.axis, t.dtype) + amax = relay.qnn.op.quantize(amax, scale, z_p, t.axis, t.dtype) out = relay.op.minimum(relay.op.maximum(arg, amin), amax) + return [out, t] @@ -252,6 +259,7 @@ def pad(expr, type_map): t.scale, t.zero_point, out_dtype=t.dtype, + axis=t.axis ) else: ## If the pad-value is a constant, we need to quantize it @@ -340,6 +348,7 @@ def binary(expr, type_map): out_t.scale, out_t.zero_point, out_dtype=out_t.dtype, + axis=out_t.axis, ) if right_t != out_t: @@ -350,6 +359,7 @@ def binary(expr, type_map): out_t.scale, out_t.zero_point, out_dtype=out_t.dtype, + axis=out_t.axis, ) out = op(left, right) return [out, out_t] diff --git a/src/ir/affine_type.cc b/src/ir/affine_type.cc index 
3454b6011c9b..023c77f3b2f4 100644 --- a/src/ir/affine_type.cc +++ b/src/ir/affine_type.cc @@ -30,26 +30,27 @@ namespace tvm { using tvm::ReprPrinter; using namespace tvm::runtime; -TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype) { +TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis) { ObjectPtr n = make_object(); n->scale = std::move(scale); n->zero_point = std::move(zero_point); n->dtype = std::move(dtype); + n->axis = std::move(axis); data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(TensorAffineTypeNode); TVM_REGISTER_GLOBAL("ir.TensorAffineType") - .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype) { - return TensorAffineType(scale, zero_point, dtype); + .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis) { + return TensorAffineType(scale, zero_point, dtype, axis); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TensorAffineType(" << node->scale << ", " << node->zero_point << ", " - << node->dtype << ")"; + << node->dtype << ", " << node->axis << ")"; }); TupleAffineType::TupleAffineType(Array types) { diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index b5f434e74c43..c47784efca2a 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -27,6 +27,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -109,18 +110,22 @@ class SubgraphExtractor : public ExprVisitor { protected: void VisitExpr_(const CallNode* call_node) override { if (call_node->op == quantize_op_) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs != nullptr); // Only look at arg0 for quantize VisitExpr(call_node->args[0]); // Collect type of quantize ops affine_types_.Set(GetRef(call_node), TensorAffineType(call_node->args[1], call_node->args[2], - call_node->checked_type().as()->dtype)); + attrs->out_dtype, attrs->axis)); } else if (call_node->op == dequantize_op_) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs != nullptr); // Collect type of dequantize ops affine_types_.Set( GetRef(call_node), TensorAffineType(call_node->args[1], call_node->args[2], - call_node->args[0]->checked_type().as()->dtype)); + call_node->args[0]->checked_type().as()->dtype, attrs->axis)); } else { // run normally on everything else. 
ExprVisitor::VisitExpr_(call_node); diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 2bc2e4e635f0..28afdce1d27b 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -34,7 +34,7 @@ def compare_fq_to_int(expr, args, allow_rounding_error=False): .evaluate()(*args) .numpy() ) - + print(mod_int) result_int = ( relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm") .evaluate()(*args) @@ -66,6 +66,25 @@ def test_fake_quantize_conv(): compare_fq_to_int(op, [x_np, w_np]) +def test_fake_quantize_conv_per_channel(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + one = relay.const([1.0]*16) + zero = relay.const([0]*16) + + op = relay.op.nn.conv2d( + relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(0)), + relay.qnn.op.dequantize(w, relay.const(np.random.random([16]).astype("float32")), zero, axis=0), + ) + op = relay.qnn.op.quantize(op, relay.const(1.0), relay.const(0), out_dtype=out_dtype) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + + compare_fq_to_int(op, [x_np, w_np]) + + def test_fake_quantize_dense(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[128, 64], dtype="int8") @@ -318,6 +337,19 @@ def test_fake_quantize_clip(): compare_fq_to_int(op, [x_np]) +def test_fake_quantize_clip_per_channel(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") + + x = relay.qnn.op.dequantize(x, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), axis=1) + op = relay.op.clip(x, 0, 6) + op = relay.qnn.op.quantize(op, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), out_dtype="uint8", axis=1) + + x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") + + compare_fq_to_int(op, [x_np]) + + + @pytest.mark.parametrize( "operator", [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum], From 8e3dbb02272a73c4407d72a7f88a3aba5f8dabc8 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 25 Aug 2021 15:04:01 -0600 Subject: [PATCH 04/15] fix issue with per-channel bias_add --- .../transform/fake_quantization_to_integer.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 7115575397a8..4216de6329ae 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -25,12 +25,15 @@ def fold_constant(expr): return relay.transform.FoldConstantExpr(expr, tvm.IRModule()) + def get_zeros(scale): return fold_constant(relay.op.cast(relay.op.zeros_like(scale), "int32")) + def infer_shape(expr): return relay.transform.InferType()(tvm.IRModule.from_expr(expr))["main"].body.checked_type.shape + @register_fake_quantization_to_integer("qnn.dequantize") def dequantize(expr, type_map): """Remove dequantize op""" @@ -60,7 +63,11 @@ def quantize(expr, type_map): out_dtype=expr.attrs.out_dtype, axis=t.axis, ) - return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis)] + print(infer_shape(out)) + return [ + out, + TensorAffineType(expr.args[1], expr.args[2], 
expr.attrs.out_dtype, expr.attrs.axis), + ] def register_unary_identity(op_name): @@ -101,7 +108,11 @@ def bias_add(expr, type_map): b_t = type_map[b] in_scale = fold_constant(x_t.scale) in_zero_point = fold_constant(x_t.zero_point) - if not tvm.ir.structural_equal(x_t, b_t): + if not ( + tvm.ir.structural_equal(x_t.scale, b_t.scale) + and tvm.ir.structural_equal(x_t.zero_point, b_t.zero_point) + and tvm.ir.structural_equal(x_t.dtype, b_t.dtype) + ): b = relay.qnn.op.requantize( b, b_t.scale, @@ -111,6 +122,7 @@ def bias_add(expr, type_map): out_dtype=x_t.dtype, axis=0, ) + print(infer_shape(b)) out = relay.op.nn.bias_add(x, b, **expr.attrs) return [out, x_t] @@ -226,8 +238,8 @@ def clip(expr, type_map): if not isinstance(amax, relay.expr.Constant): amax = relay.op.const(amax) - scale_shape =infer_shape(scale) - if len(scale_shape)>0 and scale_shape[0] > 1: + scale_shape = infer_shape(scale) + if len(scale_shape) > 0 and scale_shape[0] > 1: b_shape = [1] * len(infer_shape(arg)) b_shape[t.axis] = -1 amin = relay.op.reshape(relay.op.broadcast_to(amin, scale_shape), b_shape) @@ -252,6 +264,7 @@ def pad(expr, type_map): ## and we need to make sure it's affine type matches the arg pad_t = type_map[pad_value] if not tvm.ir.structural_equal(t, pad_t): + print("pad", t, pad_t) pad_value = relay.qnn.op.requantize( pad_value, pad_t.scale, @@ -259,7 +272,7 @@ def pad(expr, type_map): t.scale, t.zero_point, out_dtype=t.dtype, - axis=t.axis + axis=t.axis, ) else: ## If the pad-value is a constant, we need to quantize it From 0327dabc84a96a9fb005f2a176eb2812dd1ad3b6 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 25 Aug 2021 15:33:16 -0700 Subject: [PATCH 05/15] Fix fake quantize tests (#4) * Fixed fake quantize issues. * Formatting. * Cleanup unused imports * Fix real int8 tests. 
--- include/tvm/ir/affine_type.h | 1 - python/tvm/ir/affine_type.py | 4 +- python/tvm/relay/qnn/op/legalizations.py | 19 ++++++++-- .../transform/fake_quantization_to_integer.py | 2 +- src/ir/affine_type.cc | 3 +- src/relay/qnn/op/convolution.cc | 11 +++--- src/relay/qnn/op/requantize.cc | 6 ++- .../fake_quantization_to_integer.cc | 11 +++--- .../test_pass_fake_quantization_to_integer.py | 37 +++++++++++++------ .../test_target_texture_codegen_opencl.py | 6 +-- 10 files changed, 65 insertions(+), 35 deletions(-) diff --git a/include/tvm/ir/affine_type.h b/include/tvm/ir/affine_type.h index 34d63faa3112..4ecdcae7b958 100644 --- a/include/tvm/ir/affine_type.h +++ b/include/tvm/ir/affine_type.h @@ -85,7 +85,6 @@ class TensorAffineTypeNode : public AffineTypeNode { equal->MarkGraphNode(); return equal(scale, other->scale) && equal(zero_point, other->zero_point) && equal(dtype, other->dtype) && equal(axis, other->axis); - } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py index 0d6ba7122e08..852af90673aa 100644 --- a/python/tvm/ir/affine_type.py +++ b/python/tvm/ir/affine_type.py @@ -51,7 +51,9 @@ class TensorAffineType(AffineType): """ def __init__(self, scale, zero_point, dtype, axis=-1): - self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype, axis) + self.__init_handle_by_constructor__( + _ffi_api.TensorAffineType, scale, zero_point, dtype, axis + ) @tvm._ffi.register_object("TupleAffineType") diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 3226240fbe39..06535eca312f 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -139,11 +139,22 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs shift_data = relay.subtract( - relay.cast(data, dtype="int16"), relay.cast(input_zero_point, "int16") - ) - shift_kernel = relay.subtract( - relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, "int16") + relay.cast(data, dtype="int16"), relay.cast(input_zero_point, dtype="int16") ) + # If kernel zero point is a scalar we can directly subtract it. + if len(types[3].shape) == 0: + shift_kernel = relay.subtract( + relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, dtype="int16") + ) + # Otherwise it needs to be broadcast. + else: + # Determine output axis of kernel. 
+ output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O") + shift_kernel = relay.nn.bias_add( + relay.cast(kernel, dtype="int16"), + relay.cast(kernel_zero_point, dtype="int16"), + output_axis, + ) new_attrs = {k: attrs[k] for k in attrs.keys()} return relay_op(shift_data, shift_kernel, **new_attrs) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 4216de6329ae..bf8322bcc916 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -63,7 +63,7 @@ def quantize(expr, type_map): out_dtype=expr.attrs.out_dtype, axis=t.axis, ) - print(infer_shape(out)) + return [ out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis), diff --git a/src/ir/affine_type.cc b/src/ir/affine_type.cc index 023c77f3b2f4..87235fe20ade 100644 --- a/src/ir/affine_type.cc +++ b/src/ir/affine_type.cc @@ -30,7 +30,8 @@ namespace tvm { using tvm::ReprPrinter; using namespace tvm::runtime; -TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis) { +TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, + int axis) { ObjectPtr n = make_object(); n->scale = std::move(scale); n->zero_point = std::move(zero_point); diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index cf5266485f2e..5782f1f6b4d1 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -495,7 +495,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, * \param input_zero_point The input zero point expr. * \param param The qnn conv2d attributes. * \param out_channels The number of output channels. - * \return The sequence of Relay operatos for term3. + * \return The sequence of Relay operators for term3. * \note The term3 looks like this * * Sigma(c,r,s) zp_a * QW(k, c, r, s) @@ -625,7 +625,7 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, * \node Lowering of the qnn.conv2d operator * A quantized tensor is represented in following manner * A = scale_a x (QA - zp_A) - * where QA is quantized tensor, scale_a and zp_A are quantizations + * where QA is quantized tensor, scale_a and zp_A are quantization * params. * * Quantized convolution will convolve two quantized tensors and returns a @@ -662,8 +662,8 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, * a workaround, we fall back to simpler lowering using int32 conv if * the conv is dilated. We fallback also in case of grouped conv. * - * For depthwise, we can similarly unroll the computation. The intial compute is as follows - * wehere cm = channel_multiplier + * For depthwise, we can similarly unroll the computation. The initial compute is as follows + * where cm = channel_multiplier * * Qc(n, oc, oh, ow) = Sigma(r, s) (Qw(oc/m, oc%/m, r, s) - zp_w) * * (Qa(n, oc/cm, oh + r, ow + s) - zp_a) @@ -693,12 +693,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, Expr kernel_zero_point = new_args[3]; const auto* param = attrs.as(); ICHECK(param != nullptr); - // Assertion checks for exisiing support. + // Assertion checks for existing support. 
ICHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC") << "qnn.conv2d supports only NCHW/NHWC input data layout."; ICHECK(param->kernel_layout == "OIHW" || param->kernel_layout == "HWIO" || param->kernel_layout == "HWOI") << "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout."; + ICHECK(param->kernel_size.defined()) << "qnn.conv2d requires kernel size to be specified."; int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier; std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) = diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 46de3522061b..295b58b04793 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -136,10 +136,12 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& output_zero_point, const RequantizeAttrs* param, const Array& input_shape, const DataType& out_dtype) { auto tensor = Cast(input_tensor, DataType::Int(32)); - // 1) Subtract the input_zero_point auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); if (!IsEqualScalar(input_zero_point, zero_scalar)) { - tensor = Subtract(tensor, Cast(input_zero_point, DataType::Int(32))); + // Broadcast input zero point if needed. + Expr input_zero_broadcast = + ExpandBiasToMatchAxis(input_zero_point, input_shape.size(), {param->axis}); + tensor = Subtract(tensor, Cast(input_zero_broadcast, DataType::Int(32))); } // 2) If the input and output scales are same, we can skip the fixed point multiplication. Check diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index c47784efca2a..77d18d7556f2 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -26,8 +26,8 @@ #include #include #include -#include #include +#include namespace tvm { namespace relay { @@ -115,9 +115,9 @@ class SubgraphExtractor : public ExprVisitor { // Only look at arg0 for quantize VisitExpr(call_node->args[0]); // Collect type of quantize ops - affine_types_.Set(GetRef(call_node), - TensorAffineType(call_node->args[1], call_node->args[2], - attrs->out_dtype, attrs->axis)); + affine_types_.Set( + GetRef(call_node), + TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis)); } else if (call_node->op == dequantize_op_) { const auto* attrs = call_node->attrs.as(); ICHECK(attrs != nullptr); @@ -125,7 +125,8 @@ class SubgraphExtractor : public ExprVisitor { affine_types_.Set( GetRef(call_node), TensorAffineType(call_node->args[1], call_node->args[2], - call_node->args[0]->checked_type().as()->dtype, attrs->axis)); + call_node->args[0]->checked_type().as()->dtype, + attrs->axis)); } else { // run normally on everything else. 
ExprVisitor::VisitExpr_(call_node); diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 28afdce1d27b..394cb07f7158 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -34,7 +34,6 @@ def compare_fq_to_int(expr, args, allow_rounding_error=False): .evaluate()(*args) .numpy() ) - print(mod_int) result_int = ( relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm") .evaluate()(*args) @@ -42,7 +41,7 @@ def compare_fq_to_int(expr, args, allow_rounding_error=False): ) if allow_rounding_error: - assert np.all(np.abs(result - result_int) <= 1) + assert np.all(np.abs(result.astype("int32") - result_int.astype("int32")) <= 1) else: assert np.array_equal(result, result_int) @@ -57,6 +56,7 @@ def test_fake_quantize_conv(): op = relay.op.nn.conv2d( relay.qnn.op.dequantize(x, relay.const(2.0), zero), relay.qnn.op.dequantize(w, relay.const(0.5), zero), + kernel_size=[5, 5], ) op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) @@ -70,19 +70,23 @@ def test_fake_quantize_conv_per_channel(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") - one = relay.const([1.0]*16) - zero = relay.const([0]*16) + one = relay.const([1.0] * 16) + zero = relay.const([0] * 16) op = relay.op.nn.conv2d( relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(0)), - relay.qnn.op.dequantize(w, relay.const(np.random.random([16]).astype("float32")), zero, axis=0), + relay.qnn.op.dequantize( + w, relay.const(np.random.random([16]).astype("float32")), zero, axis=0 + ), + kernel_size=[5, 5], + channels=16, ) op = relay.qnn.op.quantize(op, relay.const(1.0), relay.const(0), out_dtype=out_dtype) x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") - compare_fq_to_int(op, [x_np, w_np]) + compare_fq_to_int(op, [x_np, w_np], allow_rounding_error=True) def test_fake_quantize_dense(): @@ -131,7 +135,9 @@ def test_fake_transpose_quantize_conv(): x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) x = relay.transpose(x, [0, 3, 1, 2]) - op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(0.5), zero), kernel_size=[5, 5] + ) op = relay.qnn.op.quantize(op, one, zero) x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") @@ -149,7 +155,9 @@ def test_fake_transpose_quantize_conv_bias_add(): x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) x = relay.transpose(x, [0, 3, 1, 2]) - op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(0.5), zero), kernel_size=[5, 5] + ) op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, one, zero)) op = relay.qnn.op.quantize(op, one, zero) @@ -170,7 +178,9 @@ def test_fake_transpose_quantize_conv_bias_add_mismatch(): x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) x = relay.transpose(x, [0, 3, 1, 2]) - op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(0.5), zero), kernel_size=[5, 5] + ) op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, two, zero)) op = relay.qnn.op.quantize(op, one, zero) 
@@ -340,16 +350,19 @@ def test_fake_quantize_clip(): def test_fake_quantize_clip_per_channel(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") - x = relay.qnn.op.dequantize(x, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), axis=1) + x = relay.qnn.op.dequantize( + x, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), axis=1 + ) op = relay.op.clip(x, 0, 6) - op = relay.qnn.op.quantize(op, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), out_dtype="uint8", axis=1) + op = relay.qnn.op.quantize( + op, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), out_dtype="uint8", axis=1 + ) x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") compare_fq_to_int(op, [x_np]) - @pytest.mark.parametrize( "operator", [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum], diff --git a/tests/python/unittest/test_target_texture_codegen_opencl.py b/tests/python/unittest/test_target_texture_codegen_opencl.py index 03944c85ade5..acfadc9d51ad 100644 --- a/tests/python/unittest/test_target_texture_codegen_opencl.py +++ b/tests/python/unittest/test_target_texture_codegen_opencl.py @@ -514,7 +514,7 @@ def copy_to_texture(stage): def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None): - """Convolution operator in NCHWc layout. """ + """Convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype @@ -694,7 +694,7 @@ def copy_to_texture(stage): def compute_conv2d_NCHWc_KCRSk_acc32(Input, Filter, stride, padding, dilation, out_dtype=None): - """Convolution operator in NCHWc layout. """ + """Convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype @@ -879,7 +879,7 @@ def copy_to_texture(stage): def compute_depthwise_conv2d_NCHWc_KCRSk_acc32( Input, Filter, stride, padding, dilation, out_dtype=None ): - """Depthwise convolution operator in NCHWc layout. 
""" + """Depthwise convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype assert isinstance(stride, int) or len(stride) == 2 From f12538e095e312d44774777ba2721d74c8028464 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 25 Aug 2021 16:36:46 -0600 Subject: [PATCH 06/15] Add Relu --- .../transform/fake_quantization_to_integer.py | 26 ++++++++++++----- .../test_pass_fake_quantization_to_integer.py | 28 +++++++++++++++++++ 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index bf8322bcc916..e9d3806c9e52 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -63,7 +63,7 @@ def quantize(expr, type_map): out_dtype=expr.attrs.out_dtype, axis=t.axis, ) - + return [ out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis), @@ -122,7 +122,6 @@ def bias_add(expr, type_map): out_dtype=x_t.dtype, axis=0, ) - print(infer_shape(b)) out = relay.op.nn.bias_add(x, b, **expr.attrs) return [out, x_t] @@ -246,11 +245,25 @@ def clip(expr, type_map): amax = relay.op.reshape(relay.op.broadcast_to(amax, scale_shape), b_shape) amin = relay.qnn.op.quantize(amin, scale, z_p, t.axis, t.dtype) amax = relay.qnn.op.quantize(amax, scale, z_p, t.axis, t.dtype) - out = relay.op.minimum(relay.op.maximum(arg, amin), amax) + out = relay.op.minimum(relay.op.maximum(arg, fold_constant(amin)), fold_constant(amax)) return [out, t] +@register_fake_quantization_to_integer("nn.relu") +def relu(expr, type_map): + arg = expr.args[0] + t = type_map[arg] + scale_shape = infer_shape(t.scale) + zero = relay.const(0, dtype="float32") + if len(scale_shape) > 0 and scale_shape[0] > 1: + b_shape = [1] * len(infer_shape(arg)) + b_shape[t.axis] = -1 + zero = relay.op.reshape(relay.op.broadcast_to(zero, scale_shape), b_shape) + zero = relay.qnn.op.quantize(zero, t.scale, t.zero_point, t.axis, t.dtype) + return [relay.op.maximum(arg, fold_constant(zero)), t] + + @register_fake_quantization_to_integer("nn.pad") def pad(expr, type_map): """Rewite an nn.pad op""" @@ -264,7 +277,6 @@ def pad(expr, type_map): ## and we need to make sure it's affine type matches the arg pad_t = type_map[pad_value] if not tvm.ir.structural_equal(t, pad_t): - print("pad", t, pad_t) pad_value = relay.qnn.op.requantize( pad_value, pad_t.scale, @@ -272,7 +284,7 @@ def pad(expr, type_map): t.scale, t.zero_point, out_dtype=t.dtype, - axis=t.axis, + axis=pad_t.axis, ) else: ## If the pad-value is a constant, we need to quantize it @@ -361,7 +373,7 @@ def binary(expr, type_map): out_t.scale, out_t.zero_point, out_dtype=out_t.dtype, - axis=out_t.axis, + axis=left_t.axis, ) if right_t != out_t: @@ -372,7 +384,7 @@ def binary(expr, type_map): out_t.scale, out_t.zero_point, out_dtype=out_t.dtype, - axis=out_t.axis, + axis=right_t.axis, ) out = op(left, right) return [out, out_t] diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 394cb07f7158..e0b2b0d1e9ef 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -363,6 +363,34 @@ def test_fake_quantize_clip_per_channel(): compare_fq_to_int(op, [x_np]) +def test_fake_quantize_relu(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") + + x = relay.qnn.op.dequantize(x, 
relay.const(2.0), relay.const(114)) + op = relay.op.nn.relu(x) + op = relay.qnn.op.quantize(op, relay.const(2.0), relay.const(114), out_dtype="uint8") + + x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") + + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_relu_per_channel(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") + + x = relay.qnn.op.dequantize( + x, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), axis=1 + ) + op = relay.op.nn.relu(x) + op = relay.qnn.op.quantize( + op, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), out_dtype="uint8", axis=1 + ) + + x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") + + compare_fq_to_int(op, [x_np]) + + @pytest.mark.parametrize( "operator", [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum], From 0a0cd11e16be5c0bedc523ce3a9a0c143aa66ed5 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 25 Aug 2021 16:23:47 -0700 Subject: [PATCH 07/15] One more little one (#5) * Fixed fake quantize issues. * Formatting. * Cleanup unused imports * Fix real int8 tests. * Fix requantize shape bug. --- src/relay/qnn/op/requantize.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 295b58b04793..25214e177b9b 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -139,8 +139,11 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); if (!IsEqualScalar(input_zero_point, zero_scalar)) { // Broadcast input zero point if needed. - Expr input_zero_broadcast = - ExpandBiasToMatchAxis(input_zero_point, input_shape.size(), {param->axis}); + Expr input_zero_broadcast = ExpandBiasToMatchAxis(Reshape(input_zero_point, + { + -1, + }), + input_shape.size(), {param->axis}); tensor = Subtract(tensor, Cast(input_zero_broadcast, DataType::Int(32))); } From 7cf972994a18c28153c6ae87f050a523c849ace2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 27 Aug 2021 20:08:06 -0600 Subject: [PATCH 08/15] Non-working Per-channel Dense --- .../transform/fake_quantization_to_integer.py | 2 +- src/relay/qnn/op/dense.cc | 25 +++++---- .../test_pass_fake_quantization_to_integer.py | 51 +++++++++++++++++++ 3 files changed, 67 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index e9d3806c9e52..e8a6622f2503 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -158,7 +158,7 @@ def dense(expr, type_map): out = relay.qnn.op.dense( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype, x_t.axis)] + return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype, 1)] @register_fake_quantization_to_integer("nn.batch_matmul") diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 592fa77aed77..2e967550124a 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -61,7 +61,6 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, } } ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point - ICHECK(IsScalarType(types[3], DataType::Int(32))); // weight_zero_point ICHECK(IsScalarType(types[4], DataType::Float(32))); // 
input_scale AssignType(types[5], DataType::Float(32), param->units, reporter); // weight_scale @@ -89,10 +88,17 @@ Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel, return Dense(quantized_data, quantized_kernel, attrs->units, attrs->out_dtype); } -Expr DenseSecondTerm(const Expr& quantized_data, const Expr& kernel_zero_point) { +Expr DenseSecondTerm(const Expr& quantized_data, const Expr& kernel_zero_point, + const int out_dim_size) { Array axes = {1}; - return Multiply(kernel_zero_point, - Sum(Cast(quantized_data, DataType::Int(32)), axes, true, false)); + Expr reduced_t2 = Sum(Cast(quantized_data, DataType::Int(32)), axes, true, false); + Expr multiplied_t2; + if (!IsConstScalar(kernel_zero_point)) { + multiplied_t2 = Multiply(kernel_zero_point, MakeRepeat(reduced_t2, out_dim_size, 1)); + } else { + multiplied_t2 = Multiply(kernel_zero_point, reduced_t2); + } + return multiplied_t2; } Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& input_zero_point) { @@ -159,25 +165,24 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, Expr kernel_zero_point = new_args[3]; const auto in_shape = get_shape(arg_types[0]); + const auto w_shape = get_shape(arg_types[1]); const int reduction_dim_size = get_const_int(in_shape[1]); + const int out_dim_size = get_const_int(w_shape[0]); const auto* qnn_dense_attrs = attrs.as(); auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs); - auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point); + auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point, out_dim_size); auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point); // Extract the integer zero points. - auto kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); - if (!IsConstScalar(input_zero_point)) { - if (kernel_zero_point_int == 0) { - return Subtract(term1, term3); - } + if (!IsConstScalar(input_zero_point) || !IsConstScalar(kernel_zero_point)) { auto term4 = DenseFourthTerm(input_zero_point, kernel_zero_point, reduction_dim_size); return DenseCombineTerms(term1, term2, term3, term4); } + auto kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); auto input_zero_point_int = GetScalarFromConstant(input_zero_point); // Get all the terms as described in the comments. 
diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index e0b2b0d1e9ef..d1c75dd854a2 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -108,6 +108,31 @@ def test_fake_quantize_dense(): compare_fq_to_int(op, [x_np, w_np]) +def test_fake_quantize_dense_per_channel(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[128, 64], dtype="int8") + w = relay.var("w", shape=[256, 64], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + op = relay.op.nn.dense( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize( + w, + relay.const(np.random.random([256]).astype("float32")), + relay.const([0] * 256), + axis=0, + ), + units=256, + ) + op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) + + x_np = np.random.randint(-128, 127, size=[128, 64], dtype="int8") + w_np = np.random.randint(-128, 127, size=[256, 64], dtype="int8") + + compare_fq_to_int(op, [x_np, w_np]) + + def test_fake_quantize_batch_matmul(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[1, 128, 64], dtype="int8") @@ -168,6 +193,32 @@ def test_fake_transpose_quantize_conv_bias_add(): compare_fq_to_int(op, [x_np, w_np, bias_np]) +def test_fake_transpose_quantize_conv_bias_add_per_channel(): + x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + bias = relay.var("bias", shape=[16], dtype="int32") + one = relay.const(1.0) + zero = relay.const(0) + w_scale = (np.random.random([16]).astype("float32") - 0.5) / 10 + 0.5 + w_zp = relay.const([0] * 16) + + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + x = relay.transpose(x, [0, 3, 1, 2]) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(w_scale), w_zp, axis=0), kernel_size=[5, 5] + ) + op = relay.op.nn.bias_add( + op, relay.qnn.op.dequantize(bias, relay.const(2.0 * w_scale), w_zp, axis=0) + ) + op = relay.qnn.op.quantize(op, one, zero) + + x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32") + + compare_fq_to_int(op, [x_np, w_np, bias_np], allow_rounding_error=True) + + def test_fake_transpose_quantize_conv_bias_add_mismatch(): x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") From 90951ccd41d18a81090b5217419eb3f9aa65be0c Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 30 Aug 2021 08:13:35 -0700 Subject: [PATCH 09/15] Fix legalization for non spatial operators. (#6) * Fix legalization for non spatial operators. * Fix axis checks for end2end functionality. --- python/tvm/relay/qnn/op/legalizations.py | 18 ++++++++++++++++-- src/relay/qnn/op/dequantize.cc | 13 ++++++++++--- src/relay/qnn/op/quantize.cc | 13 ++++++++++--- .../test_pass_fake_quantization_to_integer.py | 2 +- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 06535eca312f..fd3a1686f5a8 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -20,6 +20,7 @@ import tvm from tvm import relay +from tvm._ffi.base import TVMError from .. 
import op as reg ################################################# @@ -148,8 +149,21 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): ) # Otherwise it needs to be broadcast. else: - # Determine output axis of kernel. - output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O") + # Determine output axis of kernel for spatial operations. + if hasattr(attrs, "kernel_layout"): + output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O") + # For dense operations, broadcast to [N, K] layout. + elif isinstance(attrs, relay.op.op_attrs.DenseAttrs): + output_axis = 0 + # For matrix multiplication instead expand to [K, N] layout. + elif isinstance(attrs, relay.op.op_attrs.MatmulAttrs): + output_axis = 1 + else: + raise TVMError( + "Legalization of %s is not yet supported with per channel parameters" + % str(type(attrs)) + ) + shift_kernel = relay.nn.bias_add( relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, dtype="int16"), diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 7af5c2ac1c33..53b11bba58d8 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -54,9 +54,16 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* dequantize_attrs = attrs.as(); int axis = dequantize_attrs->axis; auto rank = static_cast(data->shape.size()); - axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; - ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis << " is out of range"; - ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; + + // If zero point and scale are scalar then axis doesnt matter. + bool scale_is_scalar = (types[1].as())->shape.size() == 0; + bool zp_is_scalar = (types[2].as())->shape.size() == 0; + + if (!(scale_is_scalar && zp_is_scalar)) { + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; + ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis << " is out of range"; + ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; + } PrimExpr axis_shape; if (rank > 0) { diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 2f1d7d8da16c..94073f5d3fdc 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -52,9 +52,16 @@ bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* quantize_attrs = attrs.as(); int axis = quantize_attrs->axis; auto rank = static_cast(data->shape.size()); - axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; - ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range"; - ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; + + // If zero point and scale are scalar then axis doesnt matter. + bool scale_is_scalar = (types[1].as())->shape.size() == 0; + bool zp_is_scalar = (types[2].as())->shape.size() == 0; + + if (!(scale_is_scalar && zp_is_scalar)) { + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; + ICHECK_LT(axis, rank > 0 ? 
rank : 1) << "axis " << quantize_attrs->axis << " is out of range"; + ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; + } PrimExpr axis_shape; if (rank > 0) { diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index d1c75dd854a2..7ede17d07d99 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -130,7 +130,7 @@ def test_fake_quantize_dense_per_channel(): x_np = np.random.randint(-128, 127, size=[128, 64], dtype="int8") w_np = np.random.randint(-128, 127, size=[256, 64], dtype="int8") - compare_fq_to_int(op, [x_np, w_np]) + compare_fq_to_int(op, [x_np, w_np], allow_rounding_error=True) def test_fake_quantize_batch_matmul(): From b09f1d690c0bae06fe76d4d2aefcd725be8c8ee1 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 30 Aug 2021 09:31:29 -0600 Subject: [PATCH 10/15] fix axis normalization fix lint fix lint again --- python/tvm/relay/transform/fake_quantization_to_integer.py | 4 ++-- src/relay/qnn/op/dequantize.cc | 2 +- src/relay/qnn/op/quantize.cc | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index e8a6622f2503..39e0db4ad59d 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -139,9 +139,8 @@ def conv2d(expr, type_map): out = relay.qnn.op.conv2d( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - scale_shape = infer_shape(conv_scale) out_layout = attrs["out_layout"] if attrs["out_layout"] != "" else attrs["data_layout"] - out_axis = tvm.tir.bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1] + out_axis = bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1] return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis.value)] @@ -252,6 +251,7 @@ def clip(expr, type_map): @register_fake_quantization_to_integer("nn.relu") def relu(expr, type_map): + """Rewrite a relu op""" arg = expr.args[0] t = type_map[arg] scale_shape = infer_shape(t.scale) diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 53b11bba58d8..c843eb3f544e 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -54,13 +54,13 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* dequantize_attrs = attrs.as(); int axis = dequantize_attrs->axis; auto rank = static_cast(data->shape.size()); + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; // If zero point and scale are scalar then axis doesnt matter. bool scale_is_scalar = (types[1].as())->shape.size() == 0; bool zp_is_scalar = (types[2].as())->shape.size() == 0; if (!(scale_is_scalar && zp_is_scalar)) { - axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; ICHECK_LT(axis, rank > 0 ? 
rank : 1) << "axis " << dequantize_attrs->axis << " is out of range"; ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; } diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 94073f5d3fdc..b116eb9da034 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -52,13 +52,13 @@ bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* quantize_attrs = attrs.as(); int axis = quantize_attrs->axis; auto rank = static_cast(data->shape.size()); + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; // If zero point and scale are scalar then axis doesnt matter. bool scale_is_scalar = (types[1].as())->shape.size() == 0; bool zp_is_scalar = (types[2].as())->shape.size() == 0; if (!(scale_is_scalar && zp_is_scalar)) { - axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range"; ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; } From 4c6dc86249970e764b65040f66ddd125debef001 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 31 Aug 2021 15:47:22 -0600 Subject: [PATCH 11/15] Per channel fq2i (#8) * WIP support per-channel quantization * more WIP * More WIP * fix issue with per-channel bias_add * Fix fake quantize tests (#4) * Fixed fake quantize issues. * Formatting. * Cleanup unused imports * Fix real int8 tests. * Add Relu * One more little one (#5) * Fixed fake quantize issues. * Formatting. * Cleanup unused imports * Fix real int8 tests. * Fix requantize shape bug. * Non-working Per-channel Dense * Fix legalization for non spatial operators. (#6) * Fix legalization for non spatial operators. * Fix axis checks for end2end functionality. * fix axis normalization fix lint fix lint again * Fix bug in requantize dimension expansion. * Format. Co-authored-by: Josh Fromm --- src/relay/qnn/op/requantize.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 25214e177b9b..a7d214761b9b 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -139,11 +139,13 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); if (!IsEqualScalar(input_zero_point, zero_scalar)) { // Broadcast input zero point if needed. + int rank = static_cast(input_shape.size()); + int axis = (param->axis < 0) ? ((rank > 0) ? rank + param->axis : 0) : param->axis; Expr input_zero_broadcast = ExpandBiasToMatchAxis(Reshape(input_zero_point, { -1, }), - input_shape.size(), {param->axis}); + rank, {axis}); tensor = Subtract(tensor, Cast(input_zero_broadcast, DataType::Int(32))); } From 607ba7c39c389352f63277e50f8c1b6333f553a2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 3 Sep 2021 13:44:35 -0600 Subject: [PATCH 12/15] respond to review comments --- include/tvm/ir/affine_type.h | 2 +- python/tvm/relay/transform/fake_quantization_to_integer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/tvm/ir/affine_type.h b/include/tvm/ir/affine_type.h index 4ecdcae7b958..5726e9eec1f0 100644 --- a/include/tvm/ir/affine_type.h +++ b/include/tvm/ir/affine_type.h @@ -71,7 +71,7 @@ class TensorAffineTypeNode : public AffineTypeNode { RelayExpr zero_point; /*! \brief The data type of this type */ DataType dtype; - /*! \brief The data type of this type */ + /*! 
\brief The axis for per-channel quantization */ int axis; void VisitAttrs(tvm::AttrVisitor* v) { diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 39e0db4ad59d..01b1135adfe9 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -255,12 +255,12 @@ def relu(expr, type_map): arg = expr.args[0] t = type_map[arg] scale_shape = infer_shape(t.scale) - zero = relay.const(0, dtype="float32") + z_p = t.zero_point if len(scale_shape) > 0 and scale_shape[0] > 1: b_shape = [1] * len(infer_shape(arg)) b_shape[t.axis] = -1 - zero = relay.op.reshape(relay.op.broadcast_to(zero, scale_shape), b_shape) - zero = relay.qnn.op.quantize(zero, t.scale, t.zero_point, t.axis, t.dtype) + z_p = relay.op.reshape(relay.op.broadcast_to(z_p, scale_shape), b_shape) + zero = relay.op.cast(z_p, t.dtype) return [relay.op.maximum(arg, fold_constant(zero)), t] From d3a80498c33af72794535022e9543d27d7361ad6 Mon Sep 17 00:00:00 2001 From: An Wang Date: Fri, 3 Sep 2021 16:11:45 -0700 Subject: [PATCH 13/15] start dtos --- python/tvm/relay/transform/fake_quantization_to_integer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 01b1135adfe9..6965a38af8d7 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -295,6 +295,12 @@ def pad(expr, type_map): return [out, t] +@register_fake_quantization_to_integer("nn.depth_to_space") +def depth_to_space(exp, type_map): + """Rewrite an nn.depth_to_space op""" + breakpoint() + + def get_binary_types(expr, type_map): """Get Affine types of a binary op's inputs and unify them""" ##Support the case where one input is quantized and the other is a constant float From 865cf6762317d7963c9b5af4994f510ce96521b6 Mon Sep 17 00:00:00 2001 From: An Wang Date: Tue, 7 Sep 2021 12:07:40 -0700 Subject: [PATCH 14/15] wip depth_to_space --- python/tvm/relay/transform/fake_quantization_to_integer.py | 5 ++++- tests/python/relay/test_pass_fake_quantization_to_integer.py | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 6965a38af8d7..be826b17d84a 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -296,9 +296,12 @@ def pad(expr, type_map): @register_fake_quantization_to_integer("nn.depth_to_space") -def depth_to_space(exp, type_map): +def depth_to_space(expr, type_map): """Rewrite an nn.depth_to_space op""" + data = expr.args[0] breakpoint() + # out = relay.op.nn.depth_to_space(data, block_size, layout, mode) + return [] def get_binary_types(expr, type_map): diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 7ede17d07d99..450dcb76c1af 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -501,3 +501,7 @@ def test_fake_quantize_pad(): x_np = np.random.randint(-25, 25, size=[1, 383, 128], dtype="int8") compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_depth_to_space(): + \ No newline at end of file From 
4d5c150ee438cda9343745dd98be95d312c43ba4 Mon Sep 17 00:00:00 2001 From: An Wang Date: Tue, 7 Sep 2021 13:07:33 -0700 Subject: [PATCH 15/15] dtos ident --- .../relay/transform/fake_quantization_to_integer.py | 10 +--------- .../relay/test_pass_fake_quantization_to_integer.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index be826b17d84a..c9ffa9264dc2 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -87,6 +87,7 @@ def identity(expr, type_map): register_unary_identity("expand_dims") register_unary_identity("nn.max_pool2d") register_unary_identity("nn.batch_flatten") +register_unary_identity("nn.depth_to_space") @register_fake_quantization_to_integer("nn.avg_pool2d") @@ -295,15 +296,6 @@ def pad(expr, type_map): return [out, t] -@register_fake_quantization_to_integer("nn.depth_to_space") -def depth_to_space(expr, type_map): - """Rewrite an nn.depth_to_space op""" - data = expr.args[0] - breakpoint() - # out = relay.op.nn.depth_to_space(data, block_size, layout, mode) - return [] - - def get_binary_types(expr, type_map): """Get Affine types of a binary op's inputs and unify them""" ##Support the case where one input is quantized and the other is a constant float diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 450dcb76c1af..3680310b4f92 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -504,4 +504,13 @@ def test_fake_quantize_pad(): def test_fake_quantize_depth_to_space(): - \ No newline at end of file + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.depth_to_space(x, 4) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + compare_fq_to_int(op, [x_np])
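
Taken together, the series threads a per-channel `axis` through `TensorAffineType` (patches 01, 03, and 12) and replaces scalar zero points with vectors matching the scale shape via the `get_zeros` helper. Below is a minimal sketch of the resulting Python API, assuming a TVM build with these patches applied:

    # Sketch only: assumes TensorAffineType accepts the `axis` argument
    # (default -1) added by this series.
    import numpy as np
    import tvm
    from tvm import relay
    from tvm.ir import TensorAffineType

    # Per-channel scales for a 16-output-channel OIHW kernel. The zero point
    # must broadcast the same way, so it is a vector of 16 int32 zeros,
    # exactly what get_zeros(scale) in patch 03 produces:
    #   fold_constant(relay.op.cast(relay.op.zeros_like(scale), "int32"))
    w_scale = relay.const(np.random.random([16]).astype("float32"))
    w_zp = relay.const(np.zeros([16], dtype="int32"))

    # axis=0 marks the outermost ("O") dimension of an OIHW kernel as the
    # quantization axis, matching the axis=0 dequantize calls in the tests.
    w_type = TensorAffineType(w_scale, w_zp, "int8", axis=0)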
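The clip rewrite in the series has two paths. When scale and zero point are scalar constants, the float clip bounds are folded directly into integer constants with q = value / scale + zero_point. A worked example with the constants the tests use (scale 2.0 and zero point 114 as in the relu test, clip bounds [0, 6] as in the clip tests):

    # Scalar fast path of the clip rewrite, restated outside Relay:
    scale, z_p = 2.0, 114        # affine params of the clip input
    amin, amax = 0.0, 6.0        # float clip bounds
    new_min = int(amin / scale + z_p)   # 114
    new_max = int(amax / scale + z_p)   # 117
    # so the pass emits relay.op.clip(arg, 114, 117) on the integer tensor

In the per-channel path the bounds instead become tensors: they are broadcast to the scale shape, reshaped so the quantization axis lines up with the argument's layout, quantized with relay.qnn.op.quantize, and applied through relay.op.maximum / relay.op.minimum.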
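Finally, an end-to-end sketch of what the new per-channel tests exercise, mirroring test_fake_quantize_clip_per_channel; it assumes the pass is exposed as relay.transform.FakeQuantizationToInteger, its name in mainline TVM at the time of this series, and reuses the executor pattern from compare_fq_to_int in the test file:

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8")
    op = relay.qnn.op.dequantize(
        x, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), axis=1
    )
    op = relay.op.clip(op, 0, 6)
    op = relay.qnn.op.quantize(
        op, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]),
        out_dtype="uint8", axis=1
    )

    mod = tvm.IRModule.from_expr(op)
    mod = relay.transform.InferType()(mod)
    # Rewrites the dequantize -> clip -> quantize chain into integer ops;
    # compare_fq_to_int additionally checks both modules produce the same result.
    mod_int = relay.transform.FakeQuantizationToInteger()(mod)

    x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8")
    result = (
        relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm")
        .evaluate()(x_np)
        .numpy()
    )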