From 4d09d7d831c7c779f6119fcd828f9a35a21c2a48 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Mon, 22 Oct 2018 10:29:37 +0800
Subject: [PATCH 1/5] Tiling in batch in int8 conv2d

---
 topi/python/topi/cuda/conv2d_int8.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d_int8.py b/topi/python/topi/cuda/conv2d_int8.py
index 053c9bc6bd31..d72a2a903a02 100644
--- a/topi/python/topi/cuda/conv2d_int8.py
+++ b/topi/python/topi/cuda/conv2d_int8.py
@@ -9,7 +9,7 @@
 from ..nn.conv2d import conv2d_NCHWc_int8_prepacked
 from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
-from ..util import get_const_tuple, get_const_int, traverse_inline
+from ..util import get_const_tuple, traverse_inline
 
 
 def _conv2d_NCHWc_int8_arg_to_workload(data, kernel, stride, padding, out_dtype):
@@ -183,7 +183,7 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
         _schedule_injective(packed_data.op, s)
         _schedule_injective(packed_kernel.op, s)
     else:
-        kernel = packed_data
+        kernel = packed_kernel
 
     if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
         s[kernel].compute_inline()
@@ -191,7 +191,6 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
     if pad_data != packed_data:
         s[pad_data].compute_inline()
 
-    batch = get_const_int(packed_data.shape[0])
     if isinstance(stride, int):
         stride_h = stride_w = stride
     else:
@@ -210,25 +209,24 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
 
     # tile and bind spatial axes
     n, f, y, x, c = s[output].op.axis
+    cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
     cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
     cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
     cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)
+    # this is the scope to attach global config inside this kernel
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
     bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
     by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
     bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
 
-    # this is the scope to attach global config inside this kernel
-    kernel_scope, n = s[output].split(n, nparts=1)
-
-    max_block_z = 128
-    if batch > max_block_z:
-        _, n = s[output].split(n, factor=max_block_z)
-    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
-    fused_byx = s[output].fuse(by, bx)
-    s[output].bind(n, tvm.thread_axis("blockIdx.z"))
+    s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
     s[output].bind(bf, tvm.thread_axis("blockIdx.y"))
-    s[output].bind(fused_byx, tvm.thread_axis("blockIdx.x"))
+    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vn, tvm.thread_axis("vthread"))
     s[output].bind(vf, tvm.thread_axis("vthread"))
     s[output].bind(vy, tvm.thread_axis("vthread"))
     s[output].bind(vx, tvm.thread_axis("vthread"))
 
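[Series note, not part of the patches: PATCH 1 drops the ad-hoc batch handling (`max_block_z = 128` with a conditional split bound straight to `blockIdx.z`) in favor of a tunable four-way `tile_n` split, so the batch axis is tiled the same way as the existing `tile_f`/`tile_y`/`tile_x` axes: block part, virtual-thread part, thread part, inner loop. Below is a minimal standalone sketch of that scheme on a toy compute; the shapes and split factors are invented for illustration.

    import tvm

    A = tvm.placeholder((16, 64), name="A")
    B = tvm.compute((16, 64), lambda n, f: A[n, f] * 2, name="B")
    s = tvm.create_schedule(B.op)

    n, f = s[B].op.axis
    # cfg["tile_n"].apply(s, output, n) performs successive splits like
    # these, yielding (block, vthread, thread, inner) parts of the axis.
    bn, rest = s[B].split(n, nparts=2)
    vn, rest = s[B].split(rest, nparts=2)
    tn, ni = s[B].split(rest, nparts=4)

    s[B].bind(bn, tvm.thread_axis("blockIdx.z"))  # batch tiles across blocks
    s[B].bind(vn, tvm.thread_axis("vthread"))     # latency-hiding virtual thread
    # After PATCH 1 the thread part (tn) is still a serial loop inside each
    # block; PATCH 2 below binds it to threadIdx.z.
    print(tvm.lower(s, [A, B], simple_mode=True))

Note also that PATCH 1 fixes a genuine bug in the pre-computed path: `kernel = packed_data` assigned the wrong tensor and now reads `kernel = packed_kernel`.]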
From 6e8668dfe0837495b053ecdd42d49d9326e37ded Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Thu, 25 Oct 2018 17:52:24 +0800
Subject: [PATCH 2/5] Bind tn and fuse tyx

---
 topi/python/topi/cuda/conv2d_int8.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d_int8.py b/topi/python/topi/cuda/conv2d_int8.py
index d72a2a903a02..05f6f4ad1026 100644
--- a/topi/python/topi/cuda/conv2d_int8.py
+++ b/topi/python/topi/cuda/conv2d_int8.py
@@ -230,11 +230,18 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
     s[output].bind(vf, tvm.thread_axis("vthread"))
     s[output].bind(vy, tvm.thread_axis("vthread"))
     s[output].bind(vx, tvm.thread_axis("vthread"))
 
-    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
-    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
-    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
-    s[conv].compute_at(s[output], tx)
+    s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+    s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+    tyx = s[output].fuse(ty, tx)
+    s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+
+    # number of threads
+    n_tz = cfg["tile_n"].size[2]
+    n_ty = cfg["tile_f"].size[2]
+    n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+
+    s[conv].compute_at(s[output], tyx)
 
     # tile and bind reduction axes
     n, f, y, x, c = s[conv].op.axis
@@ -270,9 +277,9 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
         fused = s[load].fuse(n, f, y, x, oc_chunk)
         s[load].vectorize(c)
-        fused, tx = s[load].split(fused, factor=cfg["tile_x"].size[2])
-        fused, ty = s[load].split(fused, factor=cfg["tile_y"].size[2])
-        fused, tz = s[load].split(fused, factor=cfg["tile_f"].size[2])
+        fused, tx = s[load].split(fused, factor=n_tx)
+        fused, ty = s[load].split(fused, factor=n_ty)
+        fused, tz = s[load].split(fused, factor=n_tz)
         s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
         s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
         s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
 

From fb3ff4166e161356a91ebf08b3bca0740bb12cbc Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Thu, 25 Oct 2018 17:59:07 +0800
Subject: [PATCH 3/5] Add test cases where batch > 1

---
 topi/tests/python/test_topi_conv2d_int8.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/topi/tests/python/test_topi_conv2d_int8.py b/topi/tests/python/test_topi_conv2d_int8.py
index 2b85b2b97cb1..223b356361fe 100644
--- a/topi/tests/python/test_topi_conv2d_int8.py
+++ b/topi/tests/python/test_topi_conv2d_int8.py
@@ -172,6 +172,10 @@ def test_conv2d_nchw():
     verify_conv2d_NCHWc_int8(1, 2048, 8, 192, 1, 1, 0)
     verify_conv2d_NCHWc_int8(1, 1024, 19, 84, 3, 1, 1)
+    # batch > 1
+    verify_conv2d_NCHWc_int8(7, 32, 149, 32, 3, 1, 0)
+    verify_conv2d_NCHWc_int8(8, 32, 149, 32, 3, 1, 0)
+    verify_conv2d_NCHWc_int8(32, 32, 149, 32, 3, 1, 0) 
 
 if __name__ == "__main__":
     test_conv2d_nchw()
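[Series note, not part of the patches: PATCH 2 binds the new thread parts (tn to threadIdx.z, tf to threadIdx.y, fused ty/tx to threadIdx.x) and records the resulting thread counts in n_tz/n_ty/n_tx so that the cooperative shared-memory load in the second hunk splits by the sizes the kernel is actually launched with. All stages in one TVM kernel share a single launch configuration, so extents bound to the same threadIdx dimension must agree across stages; hard-coding `cfg["tile_x"].size[2]` there would no longer match once ty and tx are fused. Below is a self-contained toy (tensor names, sizes, and thread counts all invented) showing the same fuse-then-split cooperative-load pattern.

    import tvm

    n_ty, n_tx = 4, 8  # stand-ins for the cfg[...].size[2] thread counts
    A = tvm.placeholder((64, 64), name="A")
    B = tvm.compute((64, 64), lambda i, j: A[i, j] + 1, name="B")
    s = tvm.create_schedule(B.op)
    AA = s.cache_read(A, "shared", [B])

    i, j = s[B].op.axis
    bi, ti = s[B].split(i, factor=n_ty)
    bj, tj = s[B].split(j, factor=n_tx)
    s[B].reorder(bi, bj, ti, tj)
    s[B].bind(bi, tvm.thread_axis("blockIdx.y"))
    s[B].bind(bj, tvm.thread_axis("blockIdx.x"))
    s[B].bind(ti, tvm.thread_axis("threadIdx.y"))
    s[B].bind(tj, tvm.thread_axis("threadIdx.x"))
    s[AA].compute_at(s[B], bj)

    # Cooperative load: fuse the cached stage's axes, then peel off chunks
    # sized by the *same* thread counts used for the compute stage.
    fused = s[AA].fuse(*s[AA].op.axis)
    fused, tx = s[AA].split(fused, factor=n_tx)
    fused, ty = s[AA].split(fused, factor=n_ty)
    s[AA].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[AA].bind(tx, tvm.thread_axis("threadIdx.x"))
    print(tvm.lower(s, [A, B], simple_mode=True))

PATCH 3 then covers the new code path with batch sizes 7, 8, and 32, i.e. both odd and power-of-two batches.]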
From 869d5a27b719438581213e95c005282fff45c301 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Fri, 26 Oct 2018 10:55:15 +0800
Subject: [PATCH 4/5] Make fuse tn,tf,ty,tx tuneable

---
 topi/python/topi/cuda/conv2d_int8.py | 33 ++++++++++++++++++----------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/topi/python/topi/cuda/conv2d_int8.py b/topi/python/topi/cuda/conv2d_int8.py
index 05f6f4ad1026..9d3757c35fbb 100644
--- a/topi/python/topi/cuda/conv2d_int8.py
+++ b/topi/python/topi/cuda/conv2d_int8.py
@@ -231,17 +231,28 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
     s[output].bind(vy, tvm.thread_axis("vthread"))
     s[output].bind(vx, tvm.thread_axis("vthread"))
 
-    s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
-    s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
-    tyx = s[output].fuse(ty, tx)
-    s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
-
-    # number of threads
-    n_tz = cfg["tile_n"].size[2]
-    n_ty = cfg["tile_f"].size[2]
-    n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
-
-    s[conv].compute_at(s[output], tyx)
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
+        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]
 
     # tile and bind reduction axes
     n, f, y, x, c = s[conv].op.axis

From 3de31d1da4a77498539e04647e89dc9abec76591 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Fri, 26 Oct 2018 11:00:36 +0800
Subject: [PATCH 5/5] Fix style

---
 topi/tests/python/test_topi_conv2d_int8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/topi/tests/python/test_topi_conv2d_int8.py b/topi/tests/python/test_topi_conv2d_int8.py
index 223b356361fe..7fea0d28a2c4 100644
--- a/topi/tests/python/test_topi_conv2d_int8.py
+++ b/topi/tests/python/test_topi_conv2d_int8.py
@@ -175,7 +175,7 @@ def test_conv2d_nchw():
     # batch > 1
     verify_conv2d_NCHWc_int8(7, 32, 149, 32, 3, 1, 0)
     verify_conv2d_NCHWc_int8(8, 32, 149, 32, 3, 1, 0)
-    verify_conv2d_NCHWc_int8(32, 32, 149, 32, 3, 1, 0) 
+    verify_conv2d_NCHWc_int8(32, 32, 149, 32, 3, 1, 0)
 
 if __name__ == "__main__":
     test_conv2d_nchw()
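[Series note, not part of the patches: PATCH 4 turns the binding layout chosen in PATCH 2 into an AutoTVM knob. Both branches launch the same total number of threads (the product of the four thread parts), but the thread block has a different shape, which changes how global loads coalesce, so it is worth letting the tuner decide per workload. The arithmetic below uses made-up tile configurations to illustrate the two layouts:

    # Each 4-way split is [block, vthread, thread, inner]; index [2] is the
    # thread part, as in cfg["tile_*"].size[2]. Values are invented.
    tile_n = [2, 1, 2, 2]
    tile_f = [2, 1, 8, 2]
    tile_y = [4, 1, 4, 1]
    tile_x = [4, 1, 4, 1]

    # fuse_yx = 1: tn->threadIdx.z, tf->threadIdx.y, fuse(ty,tx)->threadIdx.x
    n_tz, n_ty, n_tx = tile_n[2], tile_f[2], tile_y[2] * tile_x[2]
    print(n_tz, n_ty, n_tx, n_tz * n_ty * n_tx)  # 2 8 16 256

    # fuse_yx = 0: fuse(tn,tf)->threadIdx.z, ty->threadIdx.y, tx->threadIdx.x
    n_tz, n_ty, n_tx = tile_n[2] * tile_f[2], tile_y[2], tile_x[2]
    print(n_tz, n_ty, n_tx, n_tz * n_ty * n_tx)  # 16 4 4 256

PATCH 5's two lines look identical because the change appears to be removal of trailing whitespace from the test line added in PATCH 3.]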