diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py
index 1cb69d593d1a..48b2a2f97146 100644
--- a/python/tvm/topi/generic/conv2d.py
+++ b/python/tvm/topi/generic/conv2d.py
@@ -477,7 +477,7 @@ def conv2d_alter_int8_common(
     pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw))
 
     if data_tensor.dtype != data_dtype:
-        # How to convert data to int8
+        # How to convert data to uint8
         # Original --> C = A (conv) B
         # A and B are int8
         # C = (A + 128 - 128) (conv) B
@@ -485,18 +485,20 @@ def conv2d_alter_int8_common(
         # where A' = A + 128
         # and 128 (conv) B is basically a reduce on CRS axis for weights.
         #
-        # How to convert data to uint8
+        # How to convert data to int8
         # C = (A - 128 + 128) (conv) B
         # C = (A' conv B) + 128 (conv) B
         # where A' = A - 128
-        if data_dtype == "int8":
-            # shift data to int8
+        if data_dtype == "uint8":
+            # shift data to uint8
             before_shift = relay.add
             after_shift = relay.subtract
+            pad_value = 128
         else:
-            # shift data to uint8
+            # shift data to int8
             before_shift = relay.subtract
             after_shift = relay.add
+            pad_value = -128
 
         if attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO":
             adjust_shift = relay.sum(relay.cast(kernel, dtype="int32"), axis=(0, 1, 2))
@@ -514,7 +516,8 @@ def conv2d_alter_int8_common(
 
         # Do external padding as pad value has to be 128.
         if any(padding):
-            data = relay.nn.pad(data, pad_width=pad_width, pad_value=128)
+            data = relay.nn.pad(data, pad_width=pad_width, pad_value=pad_value)
+            new_attrs["padding"] = (0, 0)
 
         # Multiply 128 to adjust shift.
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index 84b72e4cffd2..6a895aaf0518 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -2137,5 +2137,80 @@ def get_subgraph(dtype):
         np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
+@pytest.mark.skip("Requires cascadelake or ARM v8.2")
+def test_conv2d_int8_alter_dtype():
+    def get_conv2d_nchw(
+        d_shape,
+        w_shape,
+        data_dtype,
+    ):
+        out_dtype = "int32"
+        strides = (1, 1)
+        padding = (1, 1)
+        data = relay.var("data", shape=d_shape, dtype=data_dtype)
+        weight = relay.var("weight", shape=w_shape, dtype="int8")
+        out_channel = w_shape[0]
+        return relay.nn.conv2d(
+            data=data,
+            weight=weight,
+            kernel_size=w_shape[2:],
+            channels=out_channel,
+            padding=padding,
+            strides=strides,
+            out_dtype=out_dtype,
+        )
+
+    I, O, H, W = 64, 64, 56, 56
+    kH = kW = 3
+
+    data_shape = (1, I, H, W)
+    weight_shape = (O, I, kH, kW)
+    bias_shape = (1, weight_shape[0], 1, 1)
+
+    bias = relay.var("bias", shape=bias_shape, dtype="int32")
+    bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32")
+    weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8")
+
+    for data_dtype, target, dot_product_instr in [
+        ("uint8", "llvm --device arm_cpu -mattr=+v8.2a,+dotprod", "sdot"),
+        ("int8", "llvm -mcpu=cascadelake", "vpdpbusd"),
+    ]:
+        conv2d = get_conv2d_nchw(data_shape, weight_shape, data_dtype)
+        bias_add = relay.add(conv2d, bias)
+        mod = tvm.IRModule.from_expr(bias_add)
+
+        if data_dtype == "uint8":
+            data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8")
+        else:
+            data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8")
+
+        params = {"weight": weight_np, "bias": bias_np}
+
+        ref = (
+            relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm")
+            .evaluate()(*[data_np, weight_np, bias_np])
+            .numpy()
+        )
+
+        dev = tvm.cpu(0)
+
+        with tvm.transform.PassContext(
+            opt_level=3,
+        ):
+            lib = relay.build(mod, target=target, params=params)
+
+        assert dot_product_instr in lib.lib.get_source("asm")
+
+        rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+        rt_mod.set_input("data", data_np)
+
+        rt_mod.run()
+
+        out = rt_mod.get_output(0).numpy()
+
+        np.testing.assert_equal(out, ref)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
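
Note: the shift identity that the rewritten comments in conv2d_alter_int8_common rely on can be sanity-checked outside of Relay. Below is a minimal sketch (illustrative only, not part of the patch; it assumes nothing beyond numpy) that uses a plain dot product as a stand-in for the convolution's reduction over the CRS axes:

    import numpy as np

    # int8-range activations A and int8-range weights B, widened to int32
    # so the reduction does not overflow.
    A = np.random.randint(-128, 128, size=(64,)).astype("int32")
    B = np.random.randint(-128, 128, size=(64,)).astype("int32")

    # C = A . B  ==  (A + 128) . B  -  128 * sum(B)
    # (A + 128) lies in the uint8 range, and 128 * sum(B) plays the role of the
    # adjust_shift term that the pass subtracts back after the convolution.
    assert np.dot(A, B) == np.dot(A + 128, B) - 128 * B.sum()

The same reasoning is why external padding uses pad_value = 128 (or -128 for the opposite shift): a zero in the original data corresponds to 128 in the shifted uint8 domain, so padding must happen before the shift is folded in and the conv's own padding attribute is then set to (0, 0).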