From 453ff392b4aa31ea0a10e86e706348ac4b23f80e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 22 Sep 2022 08:53:28 +0900 Subject: [PATCH 1/4] [Int8] Fix dtype legalize logic for CPU dot product instruction --- python/tvm/topi/generic/conv2d.py | 15 +++-- tests/python/relay/test_op_level2.py | 89 +++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 7 deletions(-) diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py index 1cb69d593d1a..48b2a2f97146 100644 --- a/python/tvm/topi/generic/conv2d.py +++ b/python/tvm/topi/generic/conv2d.py @@ -477,7 +477,7 @@ def conv2d_alter_int8_common( pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw)) if data_tensor.dtype != data_dtype: - # How to convert data to int8 + # How to convert data to uint8 # Original --> C = A (conv) B # A and B are int8 # C = (A + 128 - 128) (conv) B @@ -485,18 +485,20 @@ def conv2d_alter_int8_common( # where A' = A + 128 # and 128 (conv) B is basically a reduce on CRS axis for weights. # - # How to convert data to uint8 + # How to convert data to int8 # C = (A - 128 + 128) (conv) B # C = (A' conv B) + 128 (conv) B # where A' = A - 128 - if data_dtype == "int8": - # shift data to int8 + if data_dtype == "uint8": + # shift data to uint8 before_shift = relay.add after_shift = relay.subtract + pad_value = 128 else: - # shift data to uint8 + # shift data to int8 before_shift = relay.subtract after_shift = relay.add + pad_value = -128 if attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO": adjust_shift = relay.sum(relay.cast(kernel, dtype="int32"), axis=(0, 1, 2)) @@ -514,7 +516,8 @@ def conv2d_alter_int8_common( # Do external padding as pad value has to be 128. if any(padding): - data = relay.nn.pad(data, pad_width=pad_width, pad_value=128) + data = relay.nn.pad(data, pad_width=pad_width, pad_value=pad_value) + new_attrs["padding"] = (0, 0) # Multiply 128 to adjust shift. diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 84b72e4cffd2..d4f9d1366e25 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -2137,5 +2137,92 @@ def get_subgraph(dtype): np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) +def test_conv2d_int8_alter_dtype(): + data_dtype = "uint8" + target = "llvm --device arm_cpu -mattr=+v8.2a,+dotprod" + dot_product_instr = "sdot" + + # data_dtype = "int8" + # target = "llvm -mcpu=cascadelake" + # dot_product_instr = "vpdpbusd" + + weight_dtype = "int8" + + def get_conv2d_nchw( + d_shape, + w_shape, + padding, + strides=(1, 1), + ): + out_dtype = "int32" + + data = relay.var("data", shape=d_shape, dtype=data_dtype) + weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) + out_channel = w_shape[0] + return relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=padding, + strides=strides, + out_dtype=out_dtype, + ) + + I, O, H, W = 64, 64, 56, 56 + kH = kW = 3 + padding = (1, 1) + strides = (1, 1) + + data_shape = (1, I, H, W) + weight_shape = (O, I, kH, kW) + bias_shape = (1, weight_shape[0], 1, 1) + + bias = relay.var("bias", shape=bias_shape, dtype="int32") + + conv2d = get_conv2d_nchw(data_shape, weight_shape, padding, strides=strides) + bias_add = relay.add(conv2d, bias) + mod = tvm.IRModule.from_expr(bias_add) + + if data_dtype == "uint8": + data_np = np.random.uniform(0, 50, size=data_shape).astype("uint8") + else: + data_np = np.random.uniform(-10, 10, size=data_shape).astype("int8") + + if weight_dtype == "uint8": + weight_np = np.random.uniform(0, 255, size=weight_shape).astype("uint8") + else: + weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") + + bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32") + params = {"weight": weight_np, "bias": bias_np} + + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight_np, bias_np]) + .numpy() + ) + + dev = tvm.cpu(0) + + with tvm.transform.PassContext( + opt_level=3, + ): + lib = relay.build(mod, target=target, params=params) + + assert dot_product_instr in lib.lib.get_source("asm") + + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + rt_mod.set_input("data", data_np) + + rt_mod.run() + + out = rt_mod.get_output(0).numpy() + + np.testing.assert_equal(out, ref) + + if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_conv2d_int8_alter_dtype() From be081e71a9793df1805ca06547287287e19fd60d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 22 Sep 2022 09:16:27 +0900 Subject: [PATCH 2/4] update test --- tests/python/relay/test_op_level2.py | 81 ++++++++++++---------------- 1 file changed, 35 insertions(+), 46 deletions(-) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index d4f9d1366e25..3eb195b0c670 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -2138,26 +2138,16 @@ def get_subgraph(dtype): def test_conv2d_int8_alter_dtype(): - data_dtype = "uint8" - target = "llvm --device arm_cpu -mattr=+v8.2a,+dotprod" - dot_product_instr = "sdot" - - # data_dtype = "int8" - # target = "llvm -mcpu=cascadelake" - # dot_product_instr = "vpdpbusd" - - weight_dtype = "int8" - def get_conv2d_nchw( d_shape, w_shape, - padding, - strides=(1, 1), + data_dtype, ): out_dtype = "int32" - + strides=(1, 1) + padding = (1, 1) data = relay.var("data", shape=d_shape, dtype=data_dtype) - weight = relay.var("weight", shape=w_shape, dtype=weight_dtype) + weight = relay.var("weight", shape=w_shape, dtype="int8") out_channel = w_shape[0] return relay.nn.conv2d( data=data, @@ -2171,56 +2161,55 @@ def get_conv2d_nchw( I, O, H, W = 64, 64, 56, 56 kH = kW = 3 - padding = (1, 1) - strides = (1, 1) data_shape = (1, I, H, W) weight_shape = (O, I, kH, kW) bias_shape = (1, weight_shape[0], 1, 1) bias = relay.var("bias", shape=bias_shape, dtype="int32") + bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32") + weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") - conv2d = get_conv2d_nchw(data_shape, weight_shape, padding, strides=strides) - bias_add = relay.add(conv2d, bias) - mod = tvm.IRModule.from_expr(bias_add) - - if data_dtype == "uint8": - data_np = np.random.uniform(0, 50, size=data_shape).astype("uint8") - else: - data_np = np.random.uniform(-10, 10, size=data_shape).astype("int8") + for data_dtype, target, dot_product_instr in [("uint8", "llvm --device arm_cpu -mattr=+v8.2a,+dotprod", "sdot"), + ("int8", "llvm -mcpu=cascadelake", "vpdpbusd")]: + conv2d = get_conv2d_nchw(data_shape, weight_shape, data_dtype) + bias_add = relay.add(conv2d, bias) + mod = tvm.IRModule.from_expr(bias_add) - if weight_dtype == "uint8": - weight_np = np.random.uniform(0, 255, size=weight_shape).astype("uint8") - else: - weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") + if data_dtype == "uint8": + data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8") + else: + data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8") - bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32") - params = {"weight": weight_np, "bias": bias_np} + params = {"weight": weight_np, "bias": bias_np} - ref = ( - relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") - .evaluate()(*[data_np, weight_np, bias_np]) - .numpy() - ) + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight_np, bias_np]) + .numpy() + ) - dev = tvm.cpu(0) + dev = tvm.cpu(0) - with tvm.transform.PassContext( - opt_level=3, - ): - lib = relay.build(mod, target=target, params=params) + try: + with tvm.transform.PassContext( + opt_level=3, + ): + lib = relay.build(mod, target=target, params=params) - assert dot_product_instr in lib.lib.get_source("asm") + assert dot_product_instr in lib.lib.get_source("asm") - rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) - rt_mod.set_input("data", data_np) + rt_mod.set_input("data", data_np) - rt_mod.run() + rt_mod.run() - out = rt_mod.get_output(0).numpy() + out = rt_mod.get_output(0).numpy() - np.testing.assert_equal(out, ref) + np.testing.assert_equal(out, ref) + except Exception as _: + print("skipping target", target) if __name__ == "__main__": From 3228a5b3919faf9860545e83cd877fc674e649d2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 22 Sep 2022 09:18:42 +0900 Subject: [PATCH 3/4] skip test for now --- tests/python/relay/test_op_level2.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 3eb195b0c670..3f6ba7fb7ad3 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -2137,6 +2137,7 @@ def get_subgraph(dtype): np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) +@pytest.mark.skip("Requires cascadelake or ARM v8.2") def test_conv2d_int8_alter_dtype(): def get_conv2d_nchw( d_shape, @@ -2191,27 +2192,23 @@ def get_conv2d_nchw( dev = tvm.cpu(0) - try: - with tvm.transform.PassContext( - opt_level=3, - ): - lib = relay.build(mod, target=target, params=params) + with tvm.transform.PassContext( + opt_level=3, + ): + lib = relay.build(mod, target=target, params=params) - assert dot_product_instr in lib.lib.get_source("asm") + assert dot_product_instr in lib.lib.get_source("asm") - rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) - rt_mod.set_input("data", data_np) + rt_mod.set_input("data", data_np) - rt_mod.run() + rt_mod.run() - out = rt_mod.get_output(0).numpy() + out = rt_mod.get_output(0).numpy() - np.testing.assert_equal(out, ref) - except Exception as _: - print("skipping target", target) + np.testing.assert_equal(out, ref) if __name__ == "__main__": - # tvm.testing.main() - test_conv2d_int8_alter_dtype() + tvm.testing.main() From 316bb2345408d542f12e78d1b4c49e030222bfab Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 22 Sep 2022 09:24:01 +0900 Subject: [PATCH 4/4] black --- tests/python/relay/test_op_level2.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 3f6ba7fb7ad3..6a895aaf0518 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -2145,7 +2145,7 @@ def get_conv2d_nchw( data_dtype, ): out_dtype = "int32" - strides=(1, 1) + strides = (1, 1) padding = (1, 1) data = relay.var("data", shape=d_shape, dtype=data_dtype) weight = relay.var("weight", shape=w_shape, dtype="int8") @@ -2171,8 +2171,10 @@ def get_conv2d_nchw( bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32") weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8") - for data_dtype, target, dot_product_instr in [("uint8", "llvm --device arm_cpu -mattr=+v8.2a,+dotprod", "sdot"), - ("int8", "llvm -mcpu=cascadelake", "vpdpbusd")]: + for data_dtype, target, dot_product_instr in [ + ("uint8", "llvm --device arm_cpu -mattr=+v8.2a,+dotprod", "sdot"), + ("int8", "llvm -mcpu=cascadelake", "vpdpbusd"), + ]: conv2d = get_conv2d_nchw(data_shape, weight_shape, data_dtype) bias_add = relay.add(conv2d, bias) mod = tvm.IRModule.from_expr(bias_add)