python/tvm/topi/generic/conv2d.py (15 changes: 9 additions & 6 deletions)
@@ -477,26 +477,28 @@ def conv2d_alter_int8_common(
     pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw))
 
     if data_tensor.dtype != data_dtype:
-        # How to convert data to int8
+        # How to convert data to uint8
         # Original --> C = A (conv) B
         # A and B are int8
         # C = (A + 128 - 128) (conv) B
         # C = (A' conv B) - 128 (conv) B
         # where A' = A + 128
         # and 128 (conv) B is basically a reduce on CRS axis for weights.
         #
-        # How to convert data to uint8
+        # How to convert data to int8
         # C = (A - 128 + 128) (conv) B
         # C = (A' conv B) + 128 (conv) B
         # where A' = A - 128
-        if data_dtype == "int8":
-            # shift data to int8
+        if data_dtype == "uint8":
+            # shift data to uint8
             before_shift = relay.add
             after_shift = relay.subtract
+            pad_value = 128
         else:
-            # shift data to uint8
+            # shift data to int8
             before_shift = relay.subtract
             after_shift = relay.add
+            pad_value = -128
 
         if attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWIO":
             adjust_shift = relay.sum(relay.cast(kernel, dtype="int32"), axis=(0, 1, 2))
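The comment block above carries the core identity of this hunk: shifting int8 data A to uint8 as A' = A + 128 leaves the convolution unchanged as long as 128 times the per-output-channel weight sum over the C, R, S axes is subtracted afterwards. A minimal numpy sketch of that identity, not part of the PR, with the helper conv2d_nchw and the toy shapes being illustrative only:

import numpy as np

def conv2d_nchw(data, weight):
    # Naive NCHW convolution: stride 1, no padding, int32 accumulation (illustrative helper).
    n, c, h, w = data.shape
    o, _, kh, kw = weight.shape
    out = np.zeros((n, o, h - kh + 1, w - kw + 1), dtype="int32")
    for y in range(out.shape[2]):
        for x in range(out.shape[3]):
            patch = data[:, :, y : y + kh, x : x + kw].astype("int32")
            out[:, :, y, x] = np.tensordot(
                patch, weight.astype("int32"), axes=([1, 2, 3], [1, 2, 3])
            )
    return out

A = np.random.randint(-128, 128, size=(1, 3, 8, 8)).astype("int8")
B = np.random.randint(-128, 128, size=(4, 3, 3, 3)).astype("int8")

lhs = conv2d_nchw(A, B)                               # original:  A (conv) B
shifted = conv2d_nchw(A.astype("int32") + 128, B)     # shifted:   A' (conv) B with A' = A + 128
adjust = 128 * B.astype("int32").sum(axis=(1, 2, 3))  # "128 (conv) B": reduce over C, R, S
np.testing.assert_equal(lhs, shifted - adjust[None, :, None, None])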
@@ -514,7 +516,8 @@ def conv2d_alter_int8_common(

         # Do external padding as pad value has to be 128.
         if any(padding):
-            data = relay.nn.pad(data, pad_width=pad_width, pad_value=128)
+            data = relay.nn.pad(data, pad_width=pad_width, pad_value=pad_value)
+
         new_attrs["padding"] = (0, 0)
 
         # Multiply 128 to adjust shift.
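The second hunk follows from the same shift: once the data has been moved to uint8 (or int8), external padding has to use the shifted representation of zero, 128 or -128, rather than a hard-coded 128. A small sketch, again not part of the PR, of why shift-then-pad and pad-then-shift agree only when the pad value is shifted too:

import numpy as np

# Illustrative only: int8-like data shifted into the uint8 range by +128.
A = np.random.randint(-128, 128, size=(1, 1, 4, 4)).astype("int32")
pad = ((0, 0), (0, 0), (1, 1), (1, 1))

# Padding the shifted data with 128 matches padding the original data with 0
# and shifting afterwards; padding the shifted data with 0 would instead
# inject spurious -128 values into the convolution.
shift_then_pad = np.pad(A + 128, pad, constant_values=128)
pad_then_shift = np.pad(A, pad, constant_values=0) + 128
np.testing.assert_equal(shift_then_pad, pad_then_shift)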
tests/python/relay/test_op_level2.py (75 changes: 75 additions & 0 deletions)
@@ -2137,5 +2137,80 @@ def get_subgraph(dtype):
     np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
+@pytest.mark.skip("Requires cascadelake or ARM v8.2")
+def test_conv2d_int8_alter_dtype():
+    def get_conv2d_nchw(
+        d_shape,
+        w_shape,
+        data_dtype,
+    ):
+        out_dtype = "int32"
+        strides = (1, 1)
+        padding = (1, 1)
+        data = relay.var("data", shape=d_shape, dtype=data_dtype)
+        weight = relay.var("weight", shape=w_shape, dtype="int8")
+        out_channel = w_shape[0]
+        return relay.nn.conv2d(
+            data=data,
+            weight=weight,
+            kernel_size=w_shape[2:],
+            channels=out_channel,
+            padding=padding,
+            strides=strides,
+            out_dtype=out_dtype,
+        )
+
+    I, O, H, W = 64, 64, 56, 56
+    kH = kW = 3
+
+    data_shape = (1, I, H, W)
+    weight_shape = (O, I, kH, kW)
+    bias_shape = (1, weight_shape[0], 1, 1)
+
+    bias = relay.var("bias", shape=bias_shape, dtype="int32")
+    bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32")
+    weight_np = np.random.uniform(-128, 127, size=weight_shape).astype("int8")
+
+    for data_dtype, target, dot_product_instr in [
+        ("uint8", "llvm --device arm_cpu -mattr=+v8.2a,+dotprod", "sdot"),
+        ("int8", "llvm -mcpu=cascadelake", "vpdpbusd"),
+    ]:
+        conv2d = get_conv2d_nchw(data_shape, weight_shape, data_dtype)
+        bias_add = relay.add(conv2d, bias)
+        mod = tvm.IRModule.from_expr(bias_add)
+
+        if data_dtype == "uint8":
+            data_np = np.random.uniform(0, 255, size=data_shape).astype("uint8")
+        else:
+            data_np = np.random.uniform(-128, 127, size=data_shape).astype("int8")
+
+        params = {"weight": weight_np, "bias": bias_np}
+
+        ref = (
+            relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm")
+            .evaluate()(*[data_np, weight_np, bias_np])
+            .numpy()
+        )
+
+        dev = tvm.cpu(0)
+
+        with tvm.transform.PassContext(
+            opt_level=3,
+        ):
+            lib = relay.build(mod, target=target, params=params)
+
+        assert dot_product_instr in lib.lib.get_source("asm")
+
+        rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+        rt_mod.set_input("data", data_np)
+
+        rt_mod.run()
+
+        out = rt_mod.get_output(0).numpy()
+
+        np.testing.assert_equal(out, ref)
+
+
 if __name__ == "__main__":
     tvm.testing.main()