topi/python/topi/cuda/conv2d_int8.py — 36 additions & 20 deletions
@@ -9,7 +9,7 @@
 from ..nn.conv2d import conv2d_NCHWc_int8_prepacked
 from ..nn.pad import pad
 from ..nn.util import get_pad_tuple
-from ..util import get_const_tuple, get_const_int, traverse_inline
+from ..util import get_const_tuple, traverse_inline


 def _conv2d_NCHWc_int8_arg_to_workload(data, kernel, stride, padding, out_dtype):
@@ -183,15 +183,14 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
             _schedule_injective(packed_data.op, s)
             _schedule_injective(packed_kernel.op, s)
         else:
-            kernel = packed_data
+            kernel = packed_kernel

     if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
         s[kernel].compute_inline()

     if pad_data != packed_data:
         s[pad_data].compute_inline()

-    batch = get_const_int(packed_data.shape[0])
     if isinstance(stride, int):
         stride_h = stride_w = stride
     else:
@@ -210,33 +209,50 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):

     # tile and bind spatial axes
     n, f, y, x, c = s[output].op.axis
+    cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
     cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
     cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
     cfg.define_split("tile_x", cfg.axis(x), num_outputs=4)

+    # this is the scope to attach global config inside this kernel
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    bn, vn, tn, ni = cfg["tile_n"].apply(s, output, n)
     bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
     by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
     bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)

-    # this is the scope to attach global config inside this kernel
-    kernel_scope, n = s[output].split(n, nparts=1)
-
-    max_block_z = 128
-    if batch > max_block_z:
-        _, n = s[output].split(n, factor=max_block_z)
-    s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
-    fused_byx = s[output].fuse(by, bx)
-    s[output].bind(n, tvm.thread_axis("blockIdx.z"))
+    s[output].reorder(bn, bf, by, bx, vn, vf, vy, vx, tn, tf, ty, tx, ni, fi, yi, xi)
+    s[output].bind(bn, tvm.thread_axis("blockIdx.z"))
     s[output].bind(bf, tvm.thread_axis("blockIdx.y"))
-    s[output].bind(fused_byx, tvm.thread_axis("blockIdx.x"))
+    s[output].bind(s[output].fuse(by, bx), tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vn, tvm.thread_axis("vthread"))
     s[output].bind(vf, tvm.thread_axis("vthread"))
     s[output].bind(vy, tvm.thread_axis("vthread"))
     s[output].bind(vx, tvm.thread_axis("vthread"))
-    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
-    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
-    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))

-    s[conv].compute_at(s[output], tx)
+    cfg.define_knob("fuse_yx", [0, 1])  # fuse ty,tx or tn,tf
+    if cfg["fuse_yx"].val:
+        s[output].bind(tn, tvm.thread_axis("threadIdx.z"))
+        s[output].bind(tf, tvm.thread_axis("threadIdx.y"))
+        tyx = s[output].fuse(ty, tx)
+        s[output].bind(tyx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tyx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2]
+        n_ty = cfg["tile_f"].size[2]
+        n_tx = cfg["tile_y"].size[2] * cfg["tile_x"].size[2]
+    else:
+        s[output].bind(s[output].fuse(tn, tf), tvm.thread_axis("threadIdx.z"))
+        s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+        s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+        s[conv].compute_at(s[output], tx)
+
+        # number of threads
+        n_tz = cfg["tile_n"].size[2] * cfg["tile_f"].size[2]
+        n_ty = cfg["tile_y"].size[2]
+        n_tx = cfg["tile_x"].size[2]

     # tile and bind reduction axes
     n, f, y, x, c = s[conv].op.axis
@@ -272,9 +288,9 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output, pre_computed):
         fused = s[load].fuse(n, f, y, x, oc_chunk)
         s[load].vectorize(c)

-        fused, tx = s[load].split(fused, factor=cfg["tile_x"].size[2])
-        fused, ty = s[load].split(fused, factor=cfg["tile_y"].size[2])
-        fused, tz = s[load].split(fused, factor=cfg["tile_f"].size[2])
+        fused, tx = s[load].split(fused, factor=n_tx)
+        fused, ty = s[load].split(fused, factor=n_ty)
+        fused, tz = s[load].split(fused, factor=n_tz)
         s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
         s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
         s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
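The core of the schedule change: the batch axis now gets its own 4-way split, tile_n. Its outer part bn is bound to blockIdx.z in place of the old max_block_z split, its virtual-thread and thread parts (vn, tn) join the reorder, and the new fuse_yx knob lets the tuner either put tn on threadIdx.z, tf on threadIdx.y, and the fused (ty, tx) on threadIdx.x, or fuse (tn, tf) onto threadIdx.z with ty/tx on threadIdx.y/x. The n_tz/n_ty/n_tx values record the thread counts of whichever layout was chosen, and the cooperative-load splits at the end of the diff reuse them so the load stage stays consistent with the compute stage. The following is a minimal sketch of the same bind-batch-to-blockIdx.z pattern on a toy elementwise op, using the pre-0.5 tvm API this file uses; the shapes, split factors, and reorder are illustrative assumptions, not values taken from the PR.

import tvm

# Illustrative shapes only (assumed, not from the PR).
N, F, Y, X = 32, 64, 56, 56
data = tvm.placeholder((N, F, Y, X), name="data")
out = tvm.compute((N, F, Y, X), lambda n, f, y, x: data(n, f, y, x) + 1.0, name="out")
s = tvm.create_schedule(out.op)

n, f, y, x = s[out].op.axis
bn, ni = s[out].split(n, factor=1)    # stand-in for the outer part of cfg["tile_n"]
bf, fi = s[out].split(f, factor=16)   # assumed factor
byx = s[out].fuse(y, x)
s[out].reorder(bn, bf, byx, ni, fi)

s[out].bind(bn, tvm.thread_axis("blockIdx.z"))   # batch axis -> grid z, no max_block_z cap
s[out].bind(bf, tvm.thread_axis("blockIdx.y"))
s[out].bind(byx, tvm.thread_axis("blockIdx.x"))
s[out].bind(fi, tvm.thread_axis("threadIdx.x"))

print(tvm.lower(s, [data, out], simple_mode=True))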
topi/tests/python/test_topi_conv2d_int8.py — 4 additions & 0 deletions
@@ -172,6 +172,10 @@ def test_conv2d_nchw():
     verify_conv2d_NCHWc_int8(1, 2048, 8, 192, 1, 1, 0)
     verify_conv2d_NCHWc_int8(1, 1024, 19, 84, 3, 1, 1)

+    # batch > 1
+    verify_conv2d_NCHWc_int8(7, 32, 149, 32, 3, 1, 0)
+    verify_conv2d_NCHWc_int8(8, 32, 149, 32, 3, 1, 0)
+    verify_conv2d_NCHWc_int8(32, 32, 149, 32, 3, 1, 0)

 if __name__ == "__main__":
     test_conv2d_nchw()
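The three added cases run the same 3x3, stride-1, unpadded workload (32 input channels, 149x149 input, 32 filters) at batch sizes 7, 8, and 32. A hypothetical driver for just these cases is sketched below; it assumes the test file's directory is on PYTHONPATH and a CUDA GPU supporting dp4a is available, and the positional argument order (batch, in_channel, in_size, num_filter, kernel, stride, padding) is inferred from the existing calls rather than stated in the diff.

# Hypothetical driver; the import path is an assumption, not part of the PR.
from test_topi_conv2d_int8 import verify_conv2d_NCHWc_int8

for batch in (7, 8, 32):
    verify_conv2d_NCHWc_int8(batch, 32, 149, 32, 3, 1, 0)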