diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 8367a681d022..b4db412700a7 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -334,17 +334,37 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): cudnn_impl = True if layout == "NCHW": - # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8. assert kernel_layout == "OIHW" - strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True), - wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw), - name="group_conv2d_nchw.cuda", - ) + _, channels, _, _ = get_const_tuple(data.shape) + out_channels, in_channels, _, _ = get_const_tuple(kernel.shape) + oc_chunk = out_channels // 4 + ic_chunk = in_channels // 4 + + if ( + data.dtype in ["int8", "uint8"] + and kernel.dtype in ["int8", "uint8"] + and channels % groups == 0 + and out_channels % groups == 0 + and channels % 4 == 0 + and out_channels % 4 == 0 + and groups <= oc_chunk + and groups <= ic_chunk + ): + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.group_conv2d_nchw_int8, has_groups=True), + wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw_int8), + name="group_conv2d_nchw_int8.cuda", + ) + else: + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True), + wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw), + name="group_conv2d_nchw.cuda", + ) elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]: assert kernel_layout == "OIHW4o4i" strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True), + wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, has_groups=True), wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8), name="group_conv2d_NCHWc_int8.cuda", ) diff --git a/python/tvm/topi/cuda/group_conv2d_nchw.py b/python/tvm/topi/cuda/group_conv2d_nchw.py index 2af011700235..d75cfffc1af8 100644 
--- a/python/tvm/topi/cuda/group_conv2d_nchw.py +++ b/python/tvm/topi/cuda/group_conv2d_nchw.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name +# pylint: disable=no-value-for-parameter """The template for cuda group_conv2d_nchw""" import tvm from tvm import te @@ -23,11 +24,28 @@ from .injective import schedule_injective_from_existing from .tensor_intrin import dp4a from ..nn.pad import pad +from ..nn.conv2d import unpack_NCHWc_to_nchw from ..nn.utils import get_pad_tuple from ..utils import traverse_inline, get_const_tuple, get_const_int from .. import nn +def group_conv2d_nchw_int8(data, kernel, strides, padding, dilation, groups, out_dtype="float32"): + """Compute group_conv2d internally using group_conv2d_nchwc layout for int8 dtype""" + assert data.dtype in ("int8", "uint8") + assert kernel.dtype in ("int8", "uint8") + assert data.dtype == kernel.dtype + packed_out = group_conv2d_NCHWc_int8( + data, kernel, strides, padding, dilation, groups, out_dtype + ) + return unpack_NCHWc_to_nchw(packed_out, out_dtype) + + +def schedule_group_conv2d_nchw_int8(outs): + """Create schedule for tensors""" + return schedule_group_conv2d_NCHWc_int8(outs) + + @autotvm.register_topi_compute("group_conv2d_nchw.cuda") def group_conv2d_nchw(_, data, kernel, stride, padding, dilation, groups, out_dtype="float32"): return nn.group_conv2d_nchw(data, kernel, stride, padding, dilation, groups, out_dtype) @@ -422,7 +440,13 @@ def _schedule_group_conv2d_NCHWc_int8(cfg, s, output): oc_chunk = get_const_int(output.shape[1]) # tile and bind spatial axes - n, f, y, x, c = s[output].op.axis + if len(s[output].op.axis) == 5: + n, f, y, x, c = s[output].op.axis + else: + # For task extraction of auto-tuning, the expected output is 4D. Since auto-tuning tasks + # are created from scratch, the real auto-tuning will still happen on 5D output. 
+ n, f, y, x = s[output].op.axis + cfg.define_split("tile_n", n, num_outputs=4) cfg.define_split("tile_g", cfg.axis(groups), num_outputs=2) cfg.define_split("tile_f", cfg.axis(oc_chunk // groups), num_outputs=4) diff --git a/tests/python/topi/python/test_topi_group_conv2d.py b/tests/python/topi/python/test_topi_group_conv2d.py index e5a2fe7f28ab..55b24feece93 100644 --- a/tests/python/topi/python/test_topi_group_conv2d.py +++ b/tests/python/topi/python/test_topi_group_conv2d.py @@ -30,6 +30,22 @@ import tvm.testing +def _transform_data(data, bn): + # NCHW -> NCHW[x]c + batch_size, channel, height, width = data.shape + data = np.reshape(data, (batch_size, channel // bn, bn, height, width)) + data = np.transpose(data, (0, 1, 3, 4, 2)) + return data + + +def _transform_kernel(kernel, ic_bn, oc_bn): + # OIHW -> OIHW[x]o[x]i + out_channel, in_channel, kh, kw = kernel.shape + kernel = np.reshape(kernel, (out_channel // oc_bn, oc_bn, in_channel // ic_bn, ic_bn, kh, kw)) + kernel = np.transpose(kernel, (0, 2, 4, 5, 1, 3)) + return kernel + + _group_conv2d_nchw_implement = { "generic": (topi.nn.group_conv2d_nchw, topi.generic.schedule_group_conv2d_nchw), "gpu": (topi.cuda.group_conv2d_nchw, topi.cuda.schedule_group_conv2d_nchw), @@ -154,6 +170,7 @@ def check_target(target): oc_block_factor = 4 +ic_block_factor = 4 def verify_group_conv2d_NCHWc_int8( @@ -176,6 +193,151 @@ def verify_group_conv2d_NCHWc_int8( in_height = in_width = in_size + A = te.placeholder( + (batch, in_channel // ic_block_factor, in_height, in_width, ic_block_factor), + name="A", + dtype="int8", + ) + W = te.placeholder( + ( + num_filter // oc_block_factor, + (in_channel // groups) // ic_block_factor, + kernel, + kernel, + oc_block_factor, + ic_block_factor, + ), + name="W", + dtype="int8", + ) + bias = te.placeholder( + (num_filter // oc_block_factor, 1, 1, oc_block_factor), name="bias", dtype="int8" + ) + + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + 
@memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_NCHWc_int8") + def get_ref_data(): + a_np = np.random.randint( + low=-128, high=127, size=(batch, in_channel, in_height, in_width) + ).astype(dtype) + w_np = np.random.randint( + low=-128, high=128, size=(num_filter, in_channel // groups, kernel, kernel) + ).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) + c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding, groups).astype( + dtype + ) + + # convert to NCHWc + _, _, out_height, out_width = c_np.shape + c_np = c_np.reshape( + (batch, num_filter // oc_block_factor, oc_block_factor, out_height, out_width) + ).transpose(0, 1, 3, 4, 2) + + if add_bias: + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + + return ( + _transform_data(a_np, ic_block_factor), + _transform_kernel(w_np, ic_block_factor, oc_block_factor), + b_np, + c_np, + ) + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_target(target): + dev = tvm.device(target, 0) + if not tvm.testing.device_enabled(target): + print("Skip because %s is not enabled" % target) + return + if target == "cuda" and not tvm.contrib.nvcc.have_int8(dev.compute_version): + print("Skip because int8 intrinsics are not available") + return + + print("Running on target: %s" % target) + with tvm.target.Target(target): + C = topi.cuda.group_conv2d_NCHWc_int8(A, W, stride, padding, dilation, groups, dtype) + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = topi.cuda.schedule_group_conv2d_NCHWc_int8([C]) + + a = tvm.nd.array(a_np, dev) + w = tvm.nd.array(w_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev) + if add_bias: + func = tvm.build( + s, + [A, W, bias, C], + target, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" + % ( + batch, + in_channel, 
+ in_size, + num_filter, + kernel, + stride, + padding, + dilation, + groups, + ), + ) + func(a, w, b, c) + else: + func = tvm.build( + s, + [A, W, C], + target, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d_%d" + % ( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dilation, + groups, + ), + ) + func(a, w, c) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) + + for target in ["cuda"]: + check_target(target) + + +def verify_group_conv2d_nchw_int8( + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding, + dilation, + groups, + add_bias=False, + add_relu=False, +): + print( + "Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d)" + % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, groups) + ) + + in_height = in_width = in_size + A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype="int8") W = te.placeholder((num_filter, in_channel // groups, kernel, kernel), name="W", dtype="int8") bias = te.placeholder( @@ -187,7 +349,7 @@ def verify_group_conv2d_NCHWc_int8( bias_shape = get_const_tuple(bias.shape) dtype = A.dtype - @memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_NCHWc_int8") + @memoize("topi.tests.test_topi_group_conv2d.verify_group_conv2d_nchw_int8") def get_ref_data(): a_np = np.random.randint(low=-128, high=127, size=a_shape).astype(dtype) w_np = np.random.randint(low=-128, high=128, size=w_shape).astype(dtype) @@ -442,6 +604,30 @@ def test_group_conv2d_NCHWc_int8(): verify_group_conv2d_NCHWc_int8(9, 128, 56, 128, 3, 1, 1, 1, 32) +@tvm.testing.requires_cuda +def test_group_conv2d_nchw_int8(): + with Int8Fallback(): + # ResNeXt-50 workload + verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32) + verify_group_conv2d_nchw_int8(1, 256, 56, 256, 3, 2, 1, 1, 32) + verify_group_conv2d_nchw_int8(1, 256, 28, 256, 3, 1, 1, 1, 32) + verify_group_conv2d_nchw_int8(1, 512, 28, 512, 3, 2, 1, 1, 32) + verify_group_conv2d_nchw_int8(1, 512, 14, 512, 3, 1, 
1, 1, 32) + verify_group_conv2d_nchw_int8(1, 1024, 14, 1024, 3, 2, 1, 1, 32) + verify_group_conv2d_nchw_int8(1, 1024, 7, 1024, 3, 1, 1, 1, 32) + + # bias, relu + verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True) + verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_bias=True) + verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 1, 32, add_relu=True, add_bias=True) + # dilation + verify_group_conv2d_nchw_int8(1, 128, 56, 128, 3, 1, 1, 2, 32) + + # batch size + verify_group_conv2d_nchw_int8(2, 128, 56, 128, 3, 1, 1, 1, 32) + verify_group_conv2d_nchw_int8(9, 128, 56, 128, 3, 1, 1, 1, 32) + + def test_group_conv2d_nhwc(): # ResNeXt-50 workload verify_group_conv2d_nhwc(1, 128, 56, 128, 3, 1, 1, 1, 32) @@ -468,4 +654,5 @@ def test_group_conv2d_nhwc(): if __name__ == "__main__": test_group_conv2d_nchw() test_group_conv2d_NCHWc_int8() + test_group_conv2d_nchw_int8() test_group_conv2d_nhwc()