diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py
index 1adde9a4a430..70e4ad9ceace 100644
--- a/python/tvm/relay/transform/fake_quantization_to_integer.py
+++ b/python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -156,6 +156,25 @@ def conv2d(expr, type_map):
     return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis.value)]
 
 
+@register_fake_quantization_to_integer("nn.conv2d_transpose")
+def conv2d_transpose(expr, type_map):
+    """Rewrite a conv2d_transpose op"""
+    attrs = {**expr.attrs}
+    attrs.pop("out_dtype")
+    x, weight = expr.args
+    x_t = type_map[x]
+    w_t = type_map[weight]
+    conv_scale = fold_constant(x_t.scale * w_t.scale)
+    conv_zp = get_zeros(conv_scale)
+
+    out = relay.qnn.op.conv2d_transpose(
+        x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
+    )
+    out_layout = attrs["out_layout"] if attrs["out_layout"] != "" else attrs["data_layout"]
+    out_axis = bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1]
+    return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis.value)]
+
+
 @register_fake_quantization_to_integer("nn.dense")
 def dense(expr, type_map):
     """Rewrite a dense op"""
diff --git a/tests/python/relay/test_pass_fake_quantization_to_integer.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py
index c49d837ed920..07413b83de93 100644
--- a/tests/python/relay/test_pass_fake_quantization_to_integer.py
+++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py
@@ -89,6 +89,28 @@ def test_fake_quantize_conv_per_channel():
     compare_fq_to_int(op, [x_np, w_np], allow_rounding_error=True)
 
 
+def test_fake_quantize_transposeconv():
+    for out_dtype in ["int8", "uint8"]:
+        x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+        w = relay.var("w", shape=[3, 16, 5, 5], dtype="int8")
+        one = relay.const(1.0)
+        zero = relay.const(0)
+
+        op = relay.op.nn.conv2d_transpose(
+            relay.qnn.op.dequantize(x, relay.const(2.0), zero),
+            relay.qnn.op.dequantize(w, relay.const(0.5), zero),
+            kernel_size=[5, 5],
+            data_layout="NCHW",
+            kernel_layout="IOHW",
+        )
+        op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype)
+
+        x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8")
+        w_np = np.random.randint(-128, 127, size=[3, 16, 5, 5], dtype="int8")
+
+        compare_fq_to_int(op, [x_np, w_np])
+
+
 def test_fake_quantize_dense():
     for out_dtype in ["int8", "uint8"]:
         x = relay.var("x", shape=[128, 64], dtype="int8")
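
For reference, a minimal standalone sketch (not part of the patch) of how the new rewrite is exercised. It rebuilds the same dequantize -> conv2d_transpose -> quantize pattern as test_fake_quantize_transposeconv above and then runs the FakeQuantizationToInteger pass directly, which is what the test's compare_fq_to_int helper does internally before comparing results:

# Sketch: exercise the new nn.conv2d_transpose rewrite end-to-end.
import tvm
from tvm import relay

x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
w = relay.var("w", shape=[3, 16, 5, 5], dtype="int8")
zero = relay.const(0)

# Fake-quantized region: dequantize to float, run the float op,
# then quantize the result back to int8.
op = relay.op.nn.conv2d_transpose(
    relay.qnn.op.dequantize(x, relay.const(2.0), zero),
    relay.qnn.op.dequantize(w, relay.const(0.5), zero),
    kernel_size=[5, 5],
    data_layout="NCHW",
    kernel_layout="IOHW",
)
op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")

mod = tvm.IRModule.from_expr(op)
mod = relay.transform.InferType()(mod)
mod = relay.transform.FakeQuantizationToInteger()(mod)
# The dequantize/quantize pair is folded away and the graph now calls
# qnn.conv2d_transpose directly on the int8 tensors.
print(mod)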