From 208d048968d1cc909a798efaaa0c9b9076409460 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 23 Aug 2021 14:56:53 -0600 Subject: [PATCH 01/12] WIP support per-channel quantization --- python/tvm/relay/frontend/onnx.py | 2 +- .../transform/fake_quantization_to_integer.py | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 29221884702c..d09f74201807 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -479,7 +479,7 @@ def _impl_v1(cls, inputs, attr, params): attr["dilations"] = [1] + list(attr["dilations"]) if "pads" in attr: attr["pads"] = [0, attr["pads"][0], 0, attr["pads"][1]] - + attr["channels"] = kernel_shapes[0][0] out = AttrCvt( op_name=dimension_picker("conv"), transforms={ diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index cf55c67c8083..50628358e045 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -52,6 +52,7 @@ def quantize(expr, type_map): expr.args[1], expr.args[2], out_dtype=expr.attrs.out_dtype, + axis=expr.attrs.axis, ) return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)] @@ -116,7 +117,11 @@ def conv2d(expr, type_map): x_t = type_map[x] w_t = type_map[weight] conv_scale = fold_constant(x_t.scale * w_t.scale) - conv_zp = relay.const(0) + shape = list(relay.transform.InferType()(tvm.IRModule.from_expr(conv_scale))["main"].body.checked_type.shape) + if len(shape) == 0: + conv_zp = relay.const(0) + else: + conv_zp = relay.const([0] * shape[0].value) out = relay.qnn.op.conv2d( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) @@ -198,15 +203,22 @@ def clip(expr, type_map): amax = expr.attrs.a_max scale = fold_constant(t.scale) z_p = fold_constant(t.zero_point) - if isinstance(scale, relay.expr.Constant) and isinstance(z_p, relay.expr.Constant): + if ( + isinstance(scale, relay.expr.Constant) + and scale.data.numpy().size == 1 + and isinstance(z_p, relay.expr.Constant) + and z_p.data.numpy().size == 1 + ): scale = scale.data.numpy().item() z_p = z_p.data.numpy().item() new_min = int(amin / scale + z_p) new_max = int(amax / scale + z_p) out = relay.op.clip(arg, new_min, new_max) else: - amin = relay.op.round(relay.op.const(amin) / scale + z_p) - amax = relay.op.round(relay.op.const(amax) / scale + z_p) + amin = relay.op.cast(relay.op.round(relay.op.const(amin) / scale), t.dtype) + z_p + amax = relay.op.cast(relay.op.round(relay.op.const(amax) / scale), t.dtype) + z_p + amin = relay.op.reshape(amin, [1, -1, 1, 1]) + amax = relay.op.reshape(amax, [1, -1, 1, 1]) out = relay.op.minimum(relay.op.maximum(arg, amin), amax) return [out, t] From f30d698270d565811e223f2de3afb0927d5d62c9 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 23 Aug 2021 15:34:57 -0600 Subject: [PATCH 02/12] more WIP --- .../transform/fake_quantization_to_integer.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 50628358e045..bd504cfb96b3 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -117,11 +117,16 @@ def conv2d(expr, type_map): x_t = type_map[x] w_t = type_map[weight] conv_scale = fold_constant(x_t.scale * w_t.scale) - shape = list(relay.transform.InferType()(tvm.IRModule.from_expr(conv_scale))["main"].body.checked_type.shape) + oc_axis = attrs["kernel_layout"].find("O") + shape = list( + relay.transform.InferType()(tvm.IRModule.from_expr(conv_scale))[ + "main" + ].body.checked_type.shape + ) if len(shape) == 0: conv_zp = relay.const(0) else: - conv_zp = relay.const([0] * shape[0].value) + conv_zp = relay.const([0] * shape[oc_axis].value) out = relay.qnn.op.conv2d( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) @@ -203,6 +208,10 @@ def clip(expr, type_map): amax = expr.attrs.a_max scale = fold_constant(t.scale) z_p = fold_constant(t.zero_point) + if not isinstance(amin, relay.expr.Constant): + amin = relay.op.const(amin) + if not isinstance(amax, relay.expr.Constant): + amax = relay.op.const(amax) if ( isinstance(scale, relay.expr.Constant) and scale.data.numpy().size == 1 @@ -215,8 +224,8 @@ def clip(expr, type_map): new_max = int(amax / scale + z_p) out = relay.op.clip(arg, new_min, new_max) else: - amin = relay.op.cast(relay.op.round(relay.op.const(amin) / scale), t.dtype) + z_p - amax = relay.op.cast(relay.op.round(relay.op.const(amax) / scale), t.dtype) + z_p + amin = relay.op.cast(relay.op.round(amin / scale), t.dtype) + z_p + amax = relay.op.cast(relay.op.round(amax / scale), t.dtype) + z_p amin = relay.op.reshape(amin, [1, -1, 1, 1]) amax = relay.op.reshape(amax, [1, -1, 1, 1]) out = relay.op.minimum(relay.op.maximum(arg, amin), amax) From f461c81aba240279da0bd39b6cf784fe09550074 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 25 Aug 2021 12:25:35 -0600 Subject: [PATCH 03/12] More WIP --- include/tvm/ir/affine_type.h | 9 ++- python/tvm/ir/affine_type.py | 4 +- .../transform/fake_quantization_to_integer.py | 58 +++++++++++-------- src/ir/affine_type.cc | 9 +-- .../fake_quantization_to_integer.cc | 9 ++- .../test_pass_fake_quantization_to_integer.py | 34 ++++++++++- 6 files changed, 88 insertions(+), 35 deletions(-) diff --git a/include/tvm/ir/affine_type.h b/include/tvm/ir/affine_type.h index afbe1f343bb8..34d63faa3112 100644 --- a/include/tvm/ir/affine_type.h +++ b/include/tvm/ir/affine_type.h @@ -71,17 +71,21 @@ class TensorAffineTypeNode : public AffineTypeNode { RelayExpr zero_point; /*! \brief The data type of this type */ DataType dtype; + /*! \brief The data type of this type */ + int axis; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("scale", &scale); v->Visit("zero_point", &zero_point); v->Visit("dtype", &dtype); + v->Visit("axis", &axis); } bool SEqualReduce(const TensorAffineTypeNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal(scale, other->scale) && equal(zero_point, other->zero_point) && - equal(dtype, other->dtype); + equal(dtype, other->dtype) && equal(axis, other->axis); + } void SHashReduce(SHashReducer hash_reduce) const { @@ -89,6 +93,7 @@ class TensorAffineTypeNode : public AffineTypeNode { hash_reduce(scale); hash_reduce(zero_point); hash_reduce(dtype); + hash_reduce(axis); } static constexpr const char* _type_key = "TensorAffineType"; @@ -101,7 +106,7 @@ class TensorAffineTypeNode : public AffineTypeNode { */ class TensorAffineType : public AffineType { public: - TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype); + TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis); TVM_DEFINE_OBJECT_REF_METHODS(TensorAffineType, AffineType, TensorAffineTypeNode); }; diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py index a1ce08017b1b..0d6ba7122e08 100644 --- a/python/tvm/ir/affine_type.py +++ b/python/tvm/ir/affine_type.py @@ -50,8 +50,8 @@ class TensorAffineType(AffineType): The content data type. """ - def __init__(self, scale, zero_point, dtype): - self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype) + def __init__(self, scale, zero_point, dtype, axis=-1): + self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype, axis) @tvm._ffi.register_object("TupleAffineType") diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index bd504cfb96b3..7115575397a8 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -18,12 +18,18 @@ import tvm from tvm import relay from tvm.ir import TensorAffineType, TupleAffineType +from tvm.tir import bijective_layout from ..op import register_fake_quantization_to_integer def fold_constant(expr): return relay.transform.FoldConstantExpr(expr, tvm.IRModule()) +def get_zeros(scale): + return fold_constant(relay.op.cast(relay.op.zeros_like(scale), "int32")) + +def infer_shape(expr): + return relay.transform.InferType()(tvm.IRModule.from_expr(expr))["main"].body.checked_type.shape @register_fake_quantization_to_integer("qnn.dequantize") def dequantize(expr, type_map): @@ -52,9 +58,9 @@ def quantize(expr, type_map): expr.args[1], expr.args[2], out_dtype=expr.attrs.out_dtype, - axis=expr.attrs.axis, + axis=t.axis, ) - return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)] + return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis)] def register_unary_identity(op_name): @@ -103,6 +109,7 @@ def bias_add(expr, type_map): in_scale, in_zero_point, out_dtype=x_t.dtype, + axis=0, ) out = relay.op.nn.bias_add(x, b, **expr.attrs) return [out, x_t] @@ -117,20 +124,14 @@ def conv2d(expr, type_map): x_t = type_map[x] w_t = type_map[weight] conv_scale = fold_constant(x_t.scale * w_t.scale) - oc_axis = attrs["kernel_layout"].find("O") - shape = list( - relay.transform.InferType()(tvm.IRModule.from_expr(conv_scale))[ - "main" - ].body.checked_type.shape - ) - if len(shape) == 0: - conv_zp = relay.const(0) - else: - conv_zp = relay.const([0] * shape[oc_axis].value) + conv_zp = get_zeros(conv_scale) out = relay.qnn.op.conv2d( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype)] + scale_shape = infer_shape(conv_scale) + out_layout = attrs["out_layout"] if attrs["out_layout"] != "" else attrs["data_layout"] + out_axis = tvm.tir.bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1] + return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis.value)] @register_fake_quantization_to_integer("nn.dense") @@ -142,11 +143,11 @@ def dense(expr, type_map): x_t = type_map[x] w_t = type_map[weight] dense_scale = fold_constant(x_t.scale * w_t.scale) - dense_zp = relay.const(0) + dense_zp = get_zeros(dense_scale) out = relay.qnn.op.dense( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)] + return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype, x_t.axis)] @register_fake_quantization_to_integer("nn.batch_matmul") @@ -158,7 +159,7 @@ def batch_matmul(expr, type_map): matmul_scale = fold_constant(x_t.scale * y_t.scale) matmul_zp = relay.const(0) out = relay.qnn.op.batch_matmul(x, y, x_t.zero_point, y_t.zero_point, x_t.scale, y_t.scale) - return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype)] + return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype, x_t.axis)] @register_fake_quantization_to_integer("concatenate") @@ -208,10 +209,6 @@ def clip(expr, type_map): amax = expr.attrs.a_max scale = fold_constant(t.scale) z_p = fold_constant(t.zero_point) - if not isinstance(amin, relay.expr.Constant): - amin = relay.op.const(amin) - if not isinstance(amax, relay.expr.Constant): - amax = relay.op.const(amax) if ( isinstance(scale, relay.expr.Constant) and scale.data.numpy().size == 1 @@ -224,11 +221,21 @@ def clip(expr, type_map): new_max = int(amax / scale + z_p) out = relay.op.clip(arg, new_min, new_max) else: - amin = relay.op.cast(relay.op.round(amin / scale), t.dtype) + z_p - amax = relay.op.cast(relay.op.round(amax / scale), t.dtype) + z_p - amin = relay.op.reshape(amin, [1, -1, 1, 1]) - amax = relay.op.reshape(amax, [1, -1, 1, 1]) + if not isinstance(amin, relay.expr.Constant): + amin = relay.op.const(amin) + if not isinstance(amax, relay.expr.Constant): + amax = relay.op.const(amax) + + scale_shape =infer_shape(scale) + if len(scale_shape)>0 and scale_shape[0] > 1: + b_shape = [1] * len(infer_shape(arg)) + b_shape[t.axis] = -1 + amin = relay.op.reshape(relay.op.broadcast_to(amin, scale_shape), b_shape) + amax = relay.op.reshape(relay.op.broadcast_to(amax, scale_shape), b_shape) + amin = relay.qnn.op.quantize(amin, scale, z_p, t.axis, t.dtype) + amax = relay.qnn.op.quantize(amax, scale, z_p, t.axis, t.dtype) out = relay.op.minimum(relay.op.maximum(arg, amin), amax) + return [out, t] @@ -252,6 +259,7 @@ def pad(expr, type_map): t.scale, t.zero_point, out_dtype=t.dtype, + axis=t.axis ) else: ## If the pad-value is a constant, we need to quantize it @@ -340,6 +348,7 @@ def binary(expr, type_map): out_t.scale, out_t.zero_point, out_dtype=out_t.dtype, + axis=out_t.axis, ) if right_t != out_t: @@ -350,6 +359,7 @@ def binary(expr, type_map): out_t.scale, out_t.zero_point, out_dtype=out_t.dtype, + axis=out_t.axis, ) out = op(left, right) return [out, out_t] diff --git a/src/ir/affine_type.cc b/src/ir/affine_type.cc index 3454b6011c9b..023c77f3b2f4 100644 --- a/src/ir/affine_type.cc +++ b/src/ir/affine_type.cc @@ -30,26 +30,27 @@ namespace tvm { using tvm::ReprPrinter; using namespace tvm::runtime; -TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype) { +TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis) { ObjectPtr n = make_object(); n->scale = std::move(scale); n->zero_point = std::move(zero_point); n->dtype = std::move(dtype); + n->axis = std::move(axis); data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(TensorAffineTypeNode); TVM_REGISTER_GLOBAL("ir.TensorAffineType") - .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype) { - return TensorAffineType(scale, zero_point, dtype); + .set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis) { + return TensorAffineType(scale, zero_point, dtype, axis); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast(ref.get()); p->stream << "TensorAffineType(" << node->scale << ", " << node->zero_point << ", " - << node->dtype << ")"; + << node->dtype << ", " << node->axis << ")"; }); TupleAffineType::TupleAffineType(Array types) { diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index b5f434e74c43..c47784efca2a 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -27,6 +27,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -109,18 +110,22 @@ class SubgraphExtractor : public ExprVisitor { protected: void VisitExpr_(const CallNode* call_node) override { if (call_node->op == quantize_op_) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs != nullptr); // Only look at arg0 for quantize VisitExpr(call_node->args[0]); // Collect type of quantize ops affine_types_.Set(GetRef(call_node), TensorAffineType(call_node->args[1], call_node->args[2], - call_node->checked_type().as()->dtype)); + attrs->out_dtype, attrs->axis)); } else if (call_node->op == dequantize_op_) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs != nullptr); // Collect type of dequantize ops affine_types_.Set( GetRef(call_node), TensorAffineType(call_node->args[1], call_node->args[2], - call_node->args[0]->checked_type().as()->dtype)); + call_node->args[0]->checked_type().as()->dtype, attrs->axis)); } else { // run normally on everything else. ExprVisitor::VisitExpr_(call_node); diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 2bc2e4e635f0..28afdce1d27b 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -34,7 +34,7 @@ def compare_fq_to_int(expr, args, allow_rounding_error=False): .evaluate()(*args) .numpy() ) - + print(mod_int) result_int = ( relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm") .evaluate()(*args) @@ -66,6 +66,25 @@ def test_fake_quantize_conv(): compare_fq_to_int(op, [x_np, w_np]) +def test_fake_quantize_conv_per_channel(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + one = relay.const([1.0]*16) + zero = relay.const([0]*16) + + op = relay.op.nn.conv2d( + relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(0)), + relay.qnn.op.dequantize(w, relay.const(np.random.random([16]).astype("float32")), zero, axis=0), + ) + op = relay.qnn.op.quantize(op, relay.const(1.0), relay.const(0), out_dtype=out_dtype) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + + compare_fq_to_int(op, [x_np, w_np]) + + def test_fake_quantize_dense(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[128, 64], dtype="int8") @@ -318,6 +337,19 @@ def test_fake_quantize_clip(): compare_fq_to_int(op, [x_np]) +def test_fake_quantize_clip_per_channel(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") + + x = relay.qnn.op.dequantize(x, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), axis=1) + op = relay.op.clip(x, 0, 6) + op = relay.qnn.op.quantize(op, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), out_dtype="uint8", axis=1) + + x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") + + compare_fq_to_int(op, [x_np]) + + + @pytest.mark.parametrize( "operator", [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum], From 65ab701d9a27ad99532de2232b403df3ba2e46fe Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 25 Aug 2021 15:04:01 -0600 Subject: [PATCH 04/12] fix issue with per-channel bias_add --- .../transform/fake_quantization_to_integer.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 7115575397a8..4216de6329ae 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -25,12 +25,15 @@ def fold_constant(expr): return relay.transform.FoldConstantExpr(expr, tvm.IRModule()) + def get_zeros(scale): return fold_constant(relay.op.cast(relay.op.zeros_like(scale), "int32")) + def infer_shape(expr): return relay.transform.InferType()(tvm.IRModule.from_expr(expr))["main"].body.checked_type.shape + @register_fake_quantization_to_integer("qnn.dequantize") def dequantize(expr, type_map): """Remove dequantize op""" @@ -60,7 +63,11 @@ def quantize(expr, type_map): out_dtype=expr.attrs.out_dtype, axis=t.axis, ) - return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis)] + print(infer_shape(out)) + return [ + out, + TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis), + ] def register_unary_identity(op_name): @@ -101,7 +108,11 @@ def bias_add(expr, type_map): b_t = type_map[b] in_scale = fold_constant(x_t.scale) in_zero_point = fold_constant(x_t.zero_point) - if not tvm.ir.structural_equal(x_t, b_t): + if not ( + tvm.ir.structural_equal(x_t.scale, b_t.scale) + and tvm.ir.structural_equal(x_t.zero_point, b_t.zero_point) + and tvm.ir.structural_equal(x_t.dtype, b_t.dtype) + ): b = relay.qnn.op.requantize( b, b_t.scale, @@ -111,6 +122,7 @@ def bias_add(expr, type_map): out_dtype=x_t.dtype, axis=0, ) + print(infer_shape(b)) out = relay.op.nn.bias_add(x, b, **expr.attrs) return [out, x_t] @@ -226,8 +238,8 @@ def clip(expr, type_map): if not isinstance(amax, relay.expr.Constant): amax = relay.op.const(amax) - scale_shape =infer_shape(scale) - if len(scale_shape)>0 and scale_shape[0] > 1: + scale_shape = infer_shape(scale) + if len(scale_shape) > 0 and scale_shape[0] > 1: b_shape = [1] * len(infer_shape(arg)) b_shape[t.axis] = -1 amin = relay.op.reshape(relay.op.broadcast_to(amin, scale_shape), b_shape) @@ -252,6 +264,7 @@ def pad(expr, type_map): ## and we need to make sure it's affine type matches the arg pad_t = type_map[pad_value] if not tvm.ir.structural_equal(t, pad_t): + print("pad", t, pad_t) pad_value = relay.qnn.op.requantize( pad_value, pad_t.scale, @@ -259,7 +272,7 @@ def pad(expr, type_map): t.scale, t.zero_point, out_dtype=t.dtype, - axis=t.axis + axis=t.axis, ) else: ## If the pad-value is a constant, we need to quantize it From 3d07acb8217c805121548ee2fee42c17fee9297e Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 25 Aug 2021 15:33:16 -0700 Subject: [PATCH 05/12] Fix fake quantize tests (#4) * Fixed fake quantize issues. * Formatting. * Cleanup unused imports * Fix real int8 tests. --- include/tvm/ir/affine_type.h | 1 - python/tvm/ir/affine_type.py | 4 +- python/tvm/relay/qnn/op/legalizations.py | 19 ++++++++-- .../transform/fake_quantization_to_integer.py | 2 +- src/ir/affine_type.cc | 3 +- src/relay/qnn/op/convolution.cc | 11 +++--- src/relay/qnn/op/requantize.cc | 6 ++- .../fake_quantization_to_integer.cc | 11 +++--- .../test_pass_fake_quantization_to_integer.py | 37 +++++++++++++------ .../test_target_texture_codegen_opencl.py | 6 +-- 10 files changed, 65 insertions(+), 35 deletions(-) diff --git a/include/tvm/ir/affine_type.h b/include/tvm/ir/affine_type.h index 34d63faa3112..4ecdcae7b958 100644 --- a/include/tvm/ir/affine_type.h +++ b/include/tvm/ir/affine_type.h @@ -85,7 +85,6 @@ class TensorAffineTypeNode : public AffineTypeNode { equal->MarkGraphNode(); return equal(scale, other->scale) && equal(zero_point, other->zero_point) && equal(dtype, other->dtype) && equal(axis, other->axis); - } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py index 0d6ba7122e08..852af90673aa 100644 --- a/python/tvm/ir/affine_type.py +++ b/python/tvm/ir/affine_type.py @@ -51,7 +51,9 @@ class TensorAffineType(AffineType): """ def __init__(self, scale, zero_point, dtype, axis=-1): - self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype, axis) + self.__init_handle_by_constructor__( + _ffi_api.TensorAffineType, scale, zero_point, dtype, axis + ) @tvm._ffi.register_object("TupleAffineType") diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 3226240fbe39..06535eca312f 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -139,11 +139,22 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs shift_data = relay.subtract( - relay.cast(data, dtype="int16"), relay.cast(input_zero_point, "int16") - ) - shift_kernel = relay.subtract( - relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, "int16") + relay.cast(data, dtype="int16"), relay.cast(input_zero_point, dtype="int16") ) + # If kernel zero point is a scalar we can directly subtract it. + if len(types[3].shape) == 0: + shift_kernel = relay.subtract( + relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, dtype="int16") + ) + # Otherwise it needs to be broadcast. + else: + # Determine output axis of kernel. + output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O") + shift_kernel = relay.nn.bias_add( + relay.cast(kernel, dtype="int16"), + relay.cast(kernel_zero_point, dtype="int16"), + output_axis, + ) new_attrs = {k: attrs[k] for k in attrs.keys()} return relay_op(shift_data, shift_kernel, **new_attrs) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 4216de6329ae..bf8322bcc916 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -63,7 +63,7 @@ def quantize(expr, type_map): out_dtype=expr.attrs.out_dtype, axis=t.axis, ) - print(infer_shape(out)) + return [ out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis), diff --git a/src/ir/affine_type.cc b/src/ir/affine_type.cc index 023c77f3b2f4..87235fe20ade 100644 --- a/src/ir/affine_type.cc +++ b/src/ir/affine_type.cc @@ -30,7 +30,8 @@ namespace tvm { using tvm::ReprPrinter; using namespace tvm::runtime; -TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis) { +TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, + int axis) { ObjectPtr n = make_object(); n->scale = std::move(scale); n->zero_point = std::move(zero_point); diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index cf5266485f2e..5782f1f6b4d1 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -495,7 +495,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point, * \param input_zero_point The input zero point expr. * \param param The qnn conv2d attributes. * \param out_channels The number of output channels. - * \return The sequence of Relay operatos for term3. + * \return The sequence of Relay operators for term3. * \note The term3 looks like this * * Sigma(c,r,s) zp_a * QW(k, c, r, s) @@ -625,7 +625,7 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, * \node Lowering of the qnn.conv2d operator * A quantized tensor is represented in following manner * A = scale_a x (QA - zp_A) - * where QA is quantized tensor, scale_a and zp_A are quantizations + * where QA is quantized tensor, scale_a and zp_A are quantization * params. * * Quantized convolution will convolve two quantized tensors and returns a @@ -662,8 +662,8 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3, * a workaround, we fall back to simpler lowering using int32 conv if * the conv is dilated. We fallback also in case of grouped conv. * - * For depthwise, we can similarly unroll the computation. The intial compute is as follows - * wehere cm = channel_multiplier + * For depthwise, we can similarly unroll the computation. The initial compute is as follows + * where cm = channel_multiplier * * Qc(n, oc, oh, ow) = Sigma(r, s) (Qw(oc/m, oc%/m, r, s) - zp_w) * * (Qa(n, oc/cm, oh + r, ow + s) - zp_a) @@ -693,12 +693,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, Expr kernel_zero_point = new_args[3]; const auto* param = attrs.as(); ICHECK(param != nullptr); - // Assertion checks for exisiing support. + // Assertion checks for existing support. ICHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC") << "qnn.conv2d supports only NCHW/NHWC input data layout."; ICHECK(param->kernel_layout == "OIHW" || param->kernel_layout == "HWIO" || param->kernel_layout == "HWOI") << "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout."; + ICHECK(param->kernel_size.defined()) << "qnn.conv2d requires kernel size to be specified."; int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier; std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) = diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 46de3522061b..295b58b04793 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -136,10 +136,12 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& output_zero_point, const RequantizeAttrs* param, const Array& input_shape, const DataType& out_dtype) { auto tensor = Cast(input_tensor, DataType::Int(32)); - // 1) Subtract the input_zero_point auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); if (!IsEqualScalar(input_zero_point, zero_scalar)) { - tensor = Subtract(tensor, Cast(input_zero_point, DataType::Int(32))); + // Broadcast input zero point if needed. + Expr input_zero_broadcast = + ExpandBiasToMatchAxis(input_zero_point, input_shape.size(), {param->axis}); + tensor = Subtract(tensor, Cast(input_zero_broadcast, DataType::Int(32))); } // 2) If the input and output scales are same, we can skip the fixed point multiplication. Check diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index c47784efca2a..77d18d7556f2 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -26,8 +26,8 @@ #include #include #include -#include #include +#include namespace tvm { namespace relay { @@ -115,9 +115,9 @@ class SubgraphExtractor : public ExprVisitor { // Only look at arg0 for quantize VisitExpr(call_node->args[0]); // Collect type of quantize ops - affine_types_.Set(GetRef(call_node), - TensorAffineType(call_node->args[1], call_node->args[2], - attrs->out_dtype, attrs->axis)); + affine_types_.Set( + GetRef(call_node), + TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis)); } else if (call_node->op == dequantize_op_) { const auto* attrs = call_node->attrs.as(); ICHECK(attrs != nullptr); @@ -125,7 +125,8 @@ class SubgraphExtractor : public ExprVisitor { affine_types_.Set( GetRef(call_node), TensorAffineType(call_node->args[1], call_node->args[2], - call_node->args[0]->checked_type().as()->dtype, attrs->axis)); + call_node->args[0]->checked_type().as()->dtype, + attrs->axis)); } else { // run normally on everything else. ExprVisitor::VisitExpr_(call_node); diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 28afdce1d27b..394cb07f7158 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -34,7 +34,6 @@ def compare_fq_to_int(expr, args, allow_rounding_error=False): .evaluate()(*args) .numpy() ) - print(mod_int) result_int = ( relay.create_executor("vm", mod=mod_int, device=tvm.cpu(), target="llvm") .evaluate()(*args) @@ -42,7 +41,7 @@ def compare_fq_to_int(expr, args, allow_rounding_error=False): ) if allow_rounding_error: - assert np.all(np.abs(result - result_int) <= 1) + assert np.all(np.abs(result.astype("int32") - result_int.astype("int32")) <= 1) else: assert np.array_equal(result, result_int) @@ -57,6 +56,7 @@ def test_fake_quantize_conv(): op = relay.op.nn.conv2d( relay.qnn.op.dequantize(x, relay.const(2.0), zero), relay.qnn.op.dequantize(w, relay.const(0.5), zero), + kernel_size=[5, 5], ) op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) @@ -70,19 +70,23 @@ def test_fake_quantize_conv_per_channel(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") - one = relay.const([1.0]*16) - zero = relay.const([0]*16) + one = relay.const([1.0] * 16) + zero = relay.const([0] * 16) op = relay.op.nn.conv2d( relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(0)), - relay.qnn.op.dequantize(w, relay.const(np.random.random([16]).astype("float32")), zero, axis=0), + relay.qnn.op.dequantize( + w, relay.const(np.random.random([16]).astype("float32")), zero, axis=0 + ), + kernel_size=[5, 5], + channels=16, ) op = relay.qnn.op.quantize(op, relay.const(1.0), relay.const(0), out_dtype=out_dtype) x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") - compare_fq_to_int(op, [x_np, w_np]) + compare_fq_to_int(op, [x_np, w_np], allow_rounding_error=True) def test_fake_quantize_dense(): @@ -131,7 +135,9 @@ def test_fake_transpose_quantize_conv(): x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) x = relay.transpose(x, [0, 3, 1, 2]) - op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(0.5), zero), kernel_size=[5, 5] + ) op = relay.qnn.op.quantize(op, one, zero) x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") @@ -149,7 +155,9 @@ def test_fake_transpose_quantize_conv_bias_add(): x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) x = relay.transpose(x, [0, 3, 1, 2]) - op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(0.5), zero), kernel_size=[5, 5] + ) op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, one, zero)) op = relay.qnn.op.quantize(op, one, zero) @@ -170,7 +178,9 @@ def test_fake_transpose_quantize_conv_bias_add_mismatch(): x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) x = relay.transpose(x, [0, 3, 1, 2]) - op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(0.5), zero), kernel_size=[5, 5] + ) op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, two, zero)) op = relay.qnn.op.quantize(op, one, zero) @@ -340,16 +350,19 @@ def test_fake_quantize_clip(): def test_fake_quantize_clip_per_channel(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") - x = relay.qnn.op.dequantize(x, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), axis=1) + x = relay.qnn.op.dequantize( + x, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), axis=1 + ) op = relay.op.clip(x, 0, 6) - op = relay.qnn.op.quantize(op, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), out_dtype="uint8", axis=1) + op = relay.qnn.op.quantize( + op, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), out_dtype="uint8", axis=1 + ) x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") compare_fq_to_int(op, [x_np]) - @pytest.mark.parametrize( "operator", [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum], diff --git a/tests/python/unittest/test_target_texture_codegen_opencl.py b/tests/python/unittest/test_target_texture_codegen_opencl.py index 03944c85ade5..acfadc9d51ad 100644 --- a/tests/python/unittest/test_target_texture_codegen_opencl.py +++ b/tests/python/unittest/test_target_texture_codegen_opencl.py @@ -514,7 +514,7 @@ def copy_to_texture(stage): def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None): - """Convolution operator in NCHWc layout. """ + """Convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype @@ -694,7 +694,7 @@ def copy_to_texture(stage): def compute_conv2d_NCHWc_KCRSk_acc32(Input, Filter, stride, padding, dilation, out_dtype=None): - """Convolution operator in NCHWc layout. """ + """Convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype @@ -879,7 +879,7 @@ def copy_to_texture(stage): def compute_depthwise_conv2d_NCHWc_KCRSk_acc32( Input, Filter, stride, padding, dilation, out_dtype=None ): - """Depthwise convolution operator in NCHWc layout. """ + """Depthwise convolution operator in NCHWc layout.""" if out_dtype is None: out_dtype = Input.dtype assert isinstance(stride, int) or len(stride) == 2 From fefca2c8ac9d96ad3e60c2bd63bb3375f566a807 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 25 Aug 2021 16:36:46 -0600 Subject: [PATCH 06/12] Add Relu --- .../transform/fake_quantization_to_integer.py | 26 ++++++++++++----- .../test_pass_fake_quantization_to_integer.py | 28 +++++++++++++++++++ 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index bf8322bcc916..e9d3806c9e52 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -63,7 +63,7 @@ def quantize(expr, type_map): out_dtype=expr.attrs.out_dtype, axis=t.axis, ) - + return [ out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis), @@ -122,7 +122,6 @@ def bias_add(expr, type_map): out_dtype=x_t.dtype, axis=0, ) - print(infer_shape(b)) out = relay.op.nn.bias_add(x, b, **expr.attrs) return [out, x_t] @@ -246,11 +245,25 @@ def clip(expr, type_map): amax = relay.op.reshape(relay.op.broadcast_to(amax, scale_shape), b_shape) amin = relay.qnn.op.quantize(amin, scale, z_p, t.axis, t.dtype) amax = relay.qnn.op.quantize(amax, scale, z_p, t.axis, t.dtype) - out = relay.op.minimum(relay.op.maximum(arg, amin), amax) + out = relay.op.minimum(relay.op.maximum(arg, fold_constant(amin)), fold_constant(amax)) return [out, t] +@register_fake_quantization_to_integer("nn.relu") +def relu(expr, type_map): + arg = expr.args[0] + t = type_map[arg] + scale_shape = infer_shape(t.scale) + zero = relay.const(0, dtype="float32") + if len(scale_shape) > 0 and scale_shape[0] > 1: + b_shape = [1] * len(infer_shape(arg)) + b_shape[t.axis] = -1 + zero = relay.op.reshape(relay.op.broadcast_to(zero, scale_shape), b_shape) + zero = relay.qnn.op.quantize(zero, t.scale, t.zero_point, t.axis, t.dtype) + return [relay.op.maximum(arg, fold_constant(zero)), t] + + @register_fake_quantization_to_integer("nn.pad") def pad(expr, type_map): """Rewite an nn.pad op""" @@ -264,7 +277,6 @@ def pad(expr, type_map): ## and we need to make sure it's affine type matches the arg pad_t = type_map[pad_value] if not tvm.ir.structural_equal(t, pad_t): - print("pad", t, pad_t) pad_value = relay.qnn.op.requantize( pad_value, pad_t.scale, @@ -272,7 +284,7 @@ def pad(expr, type_map): t.scale, t.zero_point, out_dtype=t.dtype, - axis=t.axis, + axis=pad_t.axis, ) else: ## If the pad-value is a constant, we need to quantize it @@ -361,7 +373,7 @@ def binary(expr, type_map): out_t.scale, out_t.zero_point, out_dtype=out_t.dtype, - axis=out_t.axis, + axis=left_t.axis, ) if right_t != out_t: @@ -372,7 +384,7 @@ def binary(expr, type_map): out_t.scale, out_t.zero_point, out_dtype=out_t.dtype, - axis=out_t.axis, + axis=right_t.axis, ) out = op(left, right) return [out, out_t] diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index 394cb07f7158..e0b2b0d1e9ef 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -363,6 +363,34 @@ def test_fake_quantize_clip_per_channel(): compare_fq_to_int(op, [x_np]) +def test_fake_quantize_relu(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") + + x = relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(114)) + op = relay.op.nn.relu(x) + op = relay.qnn.op.quantize(op, relay.const(2.0), relay.const(114), out_dtype="uint8") + + x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") + + compare_fq_to_int(op, [x_np]) + + +def test_fake_quantize_relu_per_channel(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") + + x = relay.qnn.op.dequantize( + x, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), axis=1 + ) + op = relay.op.nn.relu(x) + op = relay.qnn.op.quantize( + op, relay.const([1.0, 2.0, 3.0]), relay.const([96, 114, 128]), out_dtype="uint8", axis=1 + ) + + x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") + + compare_fq_to_int(op, [x_np]) + + @pytest.mark.parametrize( "operator", [relay.op.add, relay.op.multiply, relay.op.subtract, relay.op.minimum, relay.op.maximum], From c4c746be5419348719deeca8133ba6a5ff679270 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 25 Aug 2021 16:23:47 -0700 Subject: [PATCH 07/12] One more little one (#5) * Fixed fake quantize issues. * Formatting. * Cleanup unused imports * Fix real int8 tests. * Fix requantize shape bug. --- src/relay/qnn/op/requantize.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 295b58b04793..25214e177b9b 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -139,8 +139,11 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); if (!IsEqualScalar(input_zero_point, zero_scalar)) { // Broadcast input zero point if needed. - Expr input_zero_broadcast = - ExpandBiasToMatchAxis(input_zero_point, input_shape.size(), {param->axis}); + Expr input_zero_broadcast = ExpandBiasToMatchAxis(Reshape(input_zero_point, + { + -1, + }), + input_shape.size(), {param->axis}); tensor = Subtract(tensor, Cast(input_zero_broadcast, DataType::Int(32))); } From fffc5499823cdb0cb79434e7a9e3450eec7aaf09 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 27 Aug 2021 20:08:06 -0600 Subject: [PATCH 08/12] Non-working Per-channel Dense --- .../transform/fake_quantization_to_integer.py | 2 +- src/relay/qnn/op/dense.cc | 25 +++++---- .../test_pass_fake_quantization_to_integer.py | 51 +++++++++++++++++++ 3 files changed, 67 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index e9d3806c9e52..e8a6622f2503 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -158,7 +158,7 @@ def dense(expr, type_map): out = relay.qnn.op.dense( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype, x_t.axis)] + return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype, 1)] @register_fake_quantization_to_integer("nn.batch_matmul") diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 592fa77aed77..2e967550124a 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -61,7 +61,6 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, } } ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point - ICHECK(IsScalarType(types[3], DataType::Int(32))); // weight_zero_point ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale AssignType(types[5], DataType::Float(32), param->units, reporter); // weight_scale @@ -89,10 +88,17 @@ Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel, return Dense(quantized_data, quantized_kernel, attrs->units, attrs->out_dtype); } -Expr DenseSecondTerm(const Expr& quantized_data, const Expr& kernel_zero_point) { +Expr DenseSecondTerm(const Expr& quantized_data, const Expr& kernel_zero_point, + const int out_dim_size) { Array axes = {1}; - return Multiply(kernel_zero_point, - Sum(Cast(quantized_data, DataType::Int(32)), axes, true, false)); + Expr reduced_t2 = Sum(Cast(quantized_data, DataType::Int(32)), axes, true, false); + Expr multiplied_t2; + if (!IsConstScalar(kernel_zero_point)) { + multiplied_t2 = Multiply(kernel_zero_point, MakeRepeat(reduced_t2, out_dim_size, 1)); + } else { + multiplied_t2 = Multiply(kernel_zero_point, reduced_t2); + } + return multiplied_t2; } Expr DenseThirdTerm(const Expr& quantized_kernel, const Expr& input_zero_point) { @@ -159,25 +165,24 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, Expr kernel_zero_point = new_args[3]; const auto in_shape = get_shape(arg_types[0]); + const auto w_shape = get_shape(arg_types[1]); const int reduction_dim_size = get_const_int(in_shape[1]); + const int out_dim_size = get_const_int(w_shape[0]); const auto* qnn_dense_attrs = attrs.as(); auto term1 = DenseFirstTerm(quantized_data, quantized_kernel, qnn_dense_attrs); - auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point); + auto term2 = DenseSecondTerm(quantized_data, kernel_zero_point, out_dim_size); auto term3 = DenseThirdTerm(quantized_kernel, input_zero_point); // Extract the integer zero points. - auto kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); - if (!IsConstScalar(input_zero_point)) { - if (kernel_zero_point_int == 0) { - return Subtract(term1, term3); - } + if (!IsConstScalar(input_zero_point) || !IsConstScalar(kernel_zero_point)) { auto term4 = DenseFourthTerm(input_zero_point, kernel_zero_point, reduction_dim_size); return DenseCombineTerms(term1, term2, term3, term4); } + auto kernel_zero_point_int = GetScalarFromConstant(kernel_zero_point); auto input_zero_point_int = GetScalarFromConstant(input_zero_point); // Get all the terms as described in the comments. diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index e0b2b0d1e9ef..d1c75dd854a2 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -108,6 +108,31 @@ def test_fake_quantize_dense(): compare_fq_to_int(op, [x_np, w_np]) +def test_fake_quantize_dense_per_channel(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[128, 64], dtype="int8") + w = relay.var("w", shape=[256, 64], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + op = relay.op.nn.dense( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize( + w, + relay.const(np.random.random([256]).astype("float32")), + relay.const([0] * 256), + axis=0, + ), + units=256, + ) + op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) + + x_np = np.random.randint(-128, 127, size=[128, 64], dtype="int8") + w_np = np.random.randint(-128, 127, size=[256, 64], dtype="int8") + + compare_fq_to_int(op, [x_np, w_np]) + + def test_fake_quantize_batch_matmul(): for out_dtype in ["int8", "uint8"]: x = relay.var("x", shape=[1, 128, 64], dtype="int8") @@ -168,6 +193,32 @@ def test_fake_transpose_quantize_conv_bias_add(): compare_fq_to_int(op, [x_np, w_np, bias_np]) +def test_fake_transpose_quantize_conv_bias_add_per_channel(): + x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + bias = relay.var("bias", shape=[16], dtype="int32") + one = relay.const(1.0) + zero = relay.const(0) + w_scale = (np.random.random([16]).astype("float32") - 0.5) / 10 + 0.5 + w_zp = relay.const([0] * 16) + + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + x = relay.transpose(x, [0, 3, 1, 2]) + op = relay.op.nn.conv2d( + x, relay.qnn.op.dequantize(w, relay.const(w_scale), w_zp, axis=0), kernel_size=[5, 5] + ) + op = relay.op.nn.bias_add( + op, relay.qnn.op.dequantize(bias, relay.const(2.0 * w_scale), w_zp, axis=0) + ) + op = relay.qnn.op.quantize(op, one, zero) + + x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32") + + compare_fq_to_int(op, [x_np, w_np, bias_np], allow_rounding_error=True) + + def test_fake_transpose_quantize_conv_bias_add_mismatch(): x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") From 4b819044a01b8b850945f5d3094bc59b538660dc Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 30 Aug 2021 08:13:35 -0700 Subject: [PATCH 09/12] Fix legalization for non spatial operators. (#6) * Fix legalization for non spatial operators. * Fix axis checks for end2end functionality. --- python/tvm/relay/qnn/op/legalizations.py | 18 ++++++++++++++++-- src/relay/qnn/op/dequantize.cc | 13 ++++++++++--- src/relay/qnn/op/quantize.cc | 13 ++++++++++--- .../test_pass_fake_quantization_to_integer.py | 2 +- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index 06535eca312f..fd3a1686f5a8 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -20,6 +20,7 @@ import tvm from tvm import relay +from tvm._ffi.base import TVMError from .. import op as reg ################################################# @@ -148,8 +149,21 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): ) # Otherwise it needs to be broadcast. else: - # Determine output axis of kernel. - output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O") + # Determine output axis of kernel for spatial operations. + if hasattr(attrs, "kernel_layout"): + output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O") + # For dense operations, broadcast to [N, K] layout. + elif isinstance(attrs, relay.op.op_attrs.DenseAttrs): + output_axis = 0 + # For matrix multiplication instead expand to [K, N] layout. + elif isinstance(attrs, relay.op.op_attrs.MatmulAttrs): + output_axis = 1 + else: + raise TVMError( + "Legalization of %s is not yet supported with per channel parameters" + % str(type(attrs)) + ) + shift_kernel = relay.nn.bias_add( relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, dtype="int16"), diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 7af5c2ac1c33..53b11bba58d8 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -54,9 +54,16 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* dequantize_attrs = attrs.as(); int axis = dequantize_attrs->axis; auto rank = static_cast(data->shape.size()); - axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; - ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis << " is out of range"; - ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; + + // If zero point and scale are scalar then axis doesnt matter. + bool scale_is_scalar = (types[1].as())->shape.size() == 0; + bool zp_is_scalar = (types[2].as())->shape.size() == 0; + + if (!(scale_is_scalar && zp_is_scalar)) { + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; + ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis << " is out of range"; + ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; + } PrimExpr axis_shape; if (rank > 0) { diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 2f1d7d8da16c..94073f5d3fdc 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -52,9 +52,16 @@ bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* quantize_attrs = attrs.as(); int axis = quantize_attrs->axis; auto rank = static_cast(data->shape.size()); - axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; - ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range"; - ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; + + // If zero point and scale are scalar then axis doesnt matter. + bool scale_is_scalar = (types[1].as())->shape.size() == 0; + bool zp_is_scalar = (types[2].as())->shape.size() == 0; + + if (!(scale_is_scalar && zp_is_scalar)) { + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; + ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range"; + ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; + } PrimExpr axis_shape; if (rank > 0) { diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py index d1c75dd854a2..7ede17d07d99 100644 --- a/tests/python/relay/test_pass_fake_quantization_to_integer.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -130,7 +130,7 @@ def test_fake_quantize_dense_per_channel(): x_np = np.random.randint(-128, 127, size=[128, 64], dtype="int8") w_np = np.random.randint(-128, 127, size=[256, 64], dtype="int8") - compare_fq_to_int(op, [x_np, w_np]) + compare_fq_to_int(op, [x_np, w_np], allow_rounding_error=True) def test_fake_quantize_batch_matmul(): From 9172d4d66d6c847478907b3555f553d2fb1ba7ca Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 30 Aug 2021 09:31:29 -0600 Subject: [PATCH 10/12] fix axis normalization fix lint fix lint again --- python/tvm/relay/transform/fake_quantization_to_integer.py | 4 ++-- src/relay/qnn/op/dequantize.cc | 2 +- src/relay/qnn/op/quantize.cc | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index e8a6622f2503..39e0db4ad59d 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -139,9 +139,8 @@ def conv2d(expr, type_map): out = relay.qnn.op.conv2d( x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs ) - scale_shape = infer_shape(conv_scale) out_layout = attrs["out_layout"] if attrs["out_layout"] != "" else attrs["data_layout"] - out_axis = tvm.tir.bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1] + out_axis = bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1] return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis.value)] @@ -252,6 +251,7 @@ def clip(expr, type_map): @register_fake_quantization_to_integer("nn.relu") def relu(expr, type_map): + """Rewrite a relu op""" arg = expr.args[0] t = type_map[arg] scale_shape = infer_shape(t.scale) diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 53b11bba58d8..c843eb3f544e 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -54,13 +54,13 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* dequantize_attrs = attrs.as(); int axis = dequantize_attrs->axis; auto rank = static_cast(data->shape.size()); + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; // If zero point and scale are scalar then axis doesnt matter. bool scale_is_scalar = (types[1].as())->shape.size() == 0; bool zp_is_scalar = (types[2].as())->shape.size() == 0; if (!(scale_is_scalar && zp_is_scalar)) { - axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << dequantize_attrs->axis << " is out of range"; ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; } diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 94073f5d3fdc..b116eb9da034 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -52,13 +52,13 @@ bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* quantize_attrs = attrs.as(); int axis = quantize_attrs->axis; auto rank = static_cast(data->shape.size()); + axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; // If zero point and scale are scalar then axis doesnt matter. bool scale_is_scalar = (types[1].as())->shape.size() == 0; bool zp_is_scalar = (types[2].as())->shape.size() == 0; if (!(scale_is_scalar && zp_is_scalar)) { - axis = (axis < 0) ? ((rank > 0) ? data->shape.size() + axis : 0) : axis; ICHECK_LT(axis, rank > 0 ? rank : 1) << "axis " << quantize_attrs->axis << " is out of range"; ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; } From 298b81d1b7df715e4ba57613868210631852def1 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 31 Aug 2021 15:47:22 -0600 Subject: [PATCH 11/12] Per channel fq2i (#8) * WIP support per-channel quantization * more WIP * More WIP * fix issue with per-channel bias_add * Fix fake quantize tests (#4) * Fixed fake quantize issues. * Formatting. * Cleanup unused imports * Fix real int8 tests. * Add Relu * One more little one (#5) * Fixed fake quantize issues. * Formatting. * Cleanup unused imports * Fix real int8 tests. * Fix requantize shape bug. * Non-working Per-channel Dense * Fix legalization for non spatial operators. (#6) * Fix legalization for non spatial operators. * Fix axis checks for end2end functionality. * fix axis normalization fix lint fix lint again * Fix bug in requantize dimension expansion. * Format. Co-authored-by: Josh Fromm --- src/relay/qnn/op/requantize.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 25214e177b9b..a7d214761b9b 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -139,11 +139,13 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale, auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0); if (!IsEqualScalar(input_zero_point, zero_scalar)) { // Broadcast input zero point if needed. + int rank = static_cast(input_shape.size()); + int axis = (param->axis < 0) ? ((rank > 0) ? rank + param->axis : 0) : param->axis; Expr input_zero_broadcast = ExpandBiasToMatchAxis(Reshape(input_zero_point, { -1, }), - input_shape.size(), {param->axis}); + rank, {axis}); tensor = Subtract(tensor, Cast(input_zero_broadcast, DataType::Int(32))); } From e5d80e5464b91b5dcca5403efc52b830e01dd52a Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 3 Sep 2021 13:44:35 -0600 Subject: [PATCH 12/12] respond to review comments respond to review comments --- include/tvm/ir/affine_type.h | 2 +- python/tvm/ir/affine_type.py | 3 +++ python/tvm/relay/qnn/op/qnn.py | 9 +++++++-- .../tvm/relay/transform/fake_quantization_to_integer.py | 9 +++++---- src/relay/qnn/op/dense.cc | 1 + 5 files changed, 17 insertions(+), 7 deletions(-) diff --git a/include/tvm/ir/affine_type.h b/include/tvm/ir/affine_type.h index 4ecdcae7b958..5726e9eec1f0 100644 --- a/include/tvm/ir/affine_type.h +++ b/include/tvm/ir/affine_type.h @@ -71,7 +71,7 @@ class TensorAffineTypeNode : public AffineTypeNode { RelayExpr zero_point; /*! \brief The data type of this type */ DataType dtype; - /*! \brief The data type of this type */ + /*! \brief The axis for per-channel quantization */ int axis; void VisitAttrs(tvm::AttrVisitor* v) { diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py index 852af90673aa..bd77c187af40 100644 --- a/python/tvm/ir/affine_type.py +++ b/python/tvm/ir/affine_type.py @@ -48,6 +48,9 @@ class TensorAffineType(AffineType): dtype : str The content data type. + + axis : int + The axis for per-channel quantization. """ def __init__(self, scale, zero_point, dtype, axis=-1): diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index e74256ec74c3..83b5cf0a831c 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -276,8 +276,10 @@ def conv2d( ): r"""Quantized 2D convolution. - This operator convolves quantized data with quantized kernel. The scale of - the output quantized tensor is the product of the kernel_scale and + This operator convolves quantized data with quantized kernel. + If doing Per-channel quantization, qnn expects the kernel_zero_scale + and optionally the kernel_zero_point will be 1-D vectors instead of scalars. + The scale of the output quantized tensor is the product of the kernel_scale and input_scale of the input quantized tensors. The zero point of the output quantized tensor is 0. By default, the dtype of output is int32. Please also refer to Requantize operator to understand how to scale back the int32 @@ -544,6 +546,9 @@ def dense( `Y = X * W` + If doing Per-channel quantization, qnn expects the kernel_zero_scale + and optionally the kernel_zero_point will be 1-D vectors instead of scalars. + Parameters ---------- data : tvm.relay.Expr diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 39e0db4ad59d..6032dbf92dbc 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -255,12 +255,13 @@ def relu(expr, type_map): arg = expr.args[0] t = type_map[arg] scale_shape = infer_shape(t.scale) - zero = relay.const(0, dtype="float32") - if len(scale_shape) > 0 and scale_shape[0] > 1: + z_p = t.zero_point + assert len(scale_shape) <= 1 + if len(scale_shape) == 1 and scale_shape[0] > 1: b_shape = [1] * len(infer_shape(arg)) b_shape[t.axis] = -1 - zero = relay.op.reshape(relay.op.broadcast_to(zero, scale_shape), b_shape) - zero = relay.qnn.op.quantize(zero, t.scale, t.zero_point, t.axis, t.dtype) + z_p = relay.op.reshape(relay.op.broadcast_to(z_p, scale_shape), b_shape) + zero = relay.op.cast(z_p, t.dtype) return [relay.op.maximum(arg, fold_constant(zero)), t] diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 2e967550124a..7b733d4777ec 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -62,6 +62,7 @@ bool QnnDenseRel(const Array& types, int num_inputs, const Attrs& attrs, } ICHECK(IsScalarType(types[2], DataType::Int(32))); // input_zero_point ICHECK(IsScalarType(types[4], DataType::Float(32))); // input_scale + // weight_zero_point can be a scalar or a vector of the same shape as the weight_scale AssignType(types[5], DataType::Float(32), param->units, reporter); // weight_scale ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";