From 858b30cede4a0485d7a38aa7bb5be15dc2009668 Mon Sep 17 00:00:00 2001
From: Tristan Konolige
Date: Tue, 18 Jan 2022 12:13:11 -0800
Subject: [PATCH] [TOPI,x86] Improve performance on int8 conv2d on x86

Fused operations appended to int8 conv2d were computed in a separate loop
from the main conv2d computation:

```
for i in ...
  parallel for j in ...
    accumulator = 0
    for k in ...
      vectorized_multiply_add(accumulator, data, kernel)
    out = accumulator
  for k in ...
    out = out + fused subsequent ops
```

This patch moves the fused ops one more loop nesting inwards to get

```
for i in ...
  parallel for j in ...
    accumulator = 0
    for k in ...
      vectorized_multiply_add(accumulator, data, kernel)
    out = accumulator + fused subsequent ops
```

On quantized mobilenetv2, this results in approximately a 30% speedup.
---
 python/tvm/topi/generic/conv2d.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py
index 640c13f4372f..dc70e0ed89f9 100644
--- a/python/tvm/topi/generic/conv2d.py
+++ b/python/tvm/topi/generic/conv2d.py
@@ -228,8 +228,8 @@ def schedule_conv_NCHWc_cpu_common_int8(
         batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
         ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
         s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+        s[C].compute_at(s[O], ow_block)
         parallel_axis = s[O].fuse(batch, oc_chunk, oh)
-        s[C].compute_at(s[O], parallel_axis)
         s[O].vectorize(oc_block)
         s[O].parallel(parallel_axis)
     elif out_ndim == 4:
@@ -237,8 +237,8 @@ def schedule_conv_NCHWc_cpu_common_int8(
         ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
         oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
         s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+        s[C].compute_at(s[O], ow_block)
         parallel_axis = s[O].fuse(batch, oc_chunk, oh)
-        s[C].compute_at(s[O], parallel_axis)
         s[O].vectorize(oc_block)
         s[O].parallel(parallel_axis)
     else:
@@ -301,8 +301,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
         s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
         s[C].vectorize(oc_block)
 
+        s[CC].compute_at(s[C], ow_inner)
         parallel_axis = s[C].fuse(batch, oc_chunk, oh_outer)
-        s[CC].compute_at(s[C], parallel_axis)
         if C == O:
             s[C].parallel(parallel_axis)
 
@@ -346,8 +346,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
         ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
         s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
 
+        s[C].compute_at(s[O], ow_inner)
         parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
-        s[C].compute_at(s[O], parallel_axis)
         s[O].vectorize(oc_block)
         s[O].parallel(parallel_axis)
     elif out_ndim == 4:
@@ -357,8 +357,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
         ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
         s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
 
+        s[C].compute_at(s[O], ow_inner)
         parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
-        s[C].compute_at(s[O], parallel_axis)
         s[O].vectorize(oc_block)
         s[O].parallel(parallel_axis)
     else:
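
For readers less familiar with TE scheduling primitives, the sketch below shows the same kind of `compute_at` placement change on a toy two-stage pipeline. It is only an illustration: the tensor names (`A`, `C`, `O`), shapes, and the factor-16 split are made up and are not the conv2d schedule touched by this patch.

```python
# Toy illustration of moving compute_at from the parallel axis to an inner
# split axis (assumed names and shapes, not the actual conv2d schedule).
# C plays the role of the accumulation stage, O the output with a fused op.
import tvm
from tvm import te

n, m = 64, 64
A = te.placeholder((n, m), name="A", dtype="int32")
C = te.compute((n, m), lambda i, j: A[i, j] * 2, name="C")  # "conv" stage
O = te.compute((n, m), lambda i, j: C[i, j] + 1, name="O")  # fused epilogue

s = te.create_schedule(O.op)
i, j = s[O].op.axis
j_outer, j_inner = s[O].split(j, factor=16)

# Old-style placement: compute C once per parallel iteration, then apply the
# epilogue in a second pass over the row:
#     s[C].compute_at(s[O], i)
# New-style placement: compute C per inner tile, so the epilogue is applied
# right after each tile of the accumulator is produced (single pass).
s[C].compute_at(s[O], j_outer)

s[O].vectorize(j_inner)
s[O].parallel(i)

# Print the lowered loop nest: C's computation sits inside the j_outer loop,
# immediately followed by the vectorized epilogue.
print(tvm.lower(s, [A, O], simple_mode=True))
```

The effect mirrors the patch: with `compute_at` targeting the inner split axis (`ow_block` / `ow_inner` in the conv2d schedules) instead of the fused parallel axis, the fused ops consume each accumulator tile right after it is produced rather than in a second pass over the output.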