diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py index 640c13f4372f..dc70e0ed89f9 100644 --- a/python/tvm/topi/generic/conv2d.py +++ b/python/tvm/topi/generic/conv2d.py @@ -228,8 +228,8 @@ def schedule_conv_NCHWc_cpu_common_int8( batch, oc_chunk, oh, ow, oc_block = s[O].op.axis ow_chunk, ow_block = s[O].split(ow, factor=reg_n) s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + s[C].compute_at(s[O], ow_block) parallel_axis = s[O].fuse(batch, oc_chunk, oh) - s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) elif out_ndim == 4: @@ -237,8 +237,8 @@ def schedule_conv_NCHWc_cpu_common_int8( ow_chunk, ow_block = s[O].split(ow, factor=reg_n) oc_chunk, oc_block = s[O].split(oc, factor=oc_bn) s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) + s[C].compute_at(s[O], ow_block) parallel_axis = s[O].fuse(batch, oc_chunk, oh) - s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) else: @@ -301,8 +301,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8( s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) s[C].vectorize(oc_block) + s[CC].compute_at(s[C], ow_inner) parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer) - s[CC].compute_at(s[C], parallel_axis) if C == O: s[C].parallel(parallel_axis) @@ -346,8 +346,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8( ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + s[C].compute_at(s[O], ow_inner) parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) - s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) elif out_ndim == 4: @@ -357,8 +357,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8( ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) + s[C].compute_at(s[O], ow_inner) parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer) - s[C].compute_at(s[O], parallel_axis) s[O].vectorize(oc_block) s[O].parallel(parallel_axis) else: