From 407bd6d94bbad8f016ec4e02a1b156fa32e5a592 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 9 Jul 2020 14:47:22 +0100 Subject: [PATCH 01/11] Improve NHWC depthwise convolution for aarch64 We created a default schedule (no auto-tuning or tensorization) named depthwise_conv2d_nhwc which does a decent job at optimizing depthwise for NHWC layouts (on aarch64). Change-Id: I01e32903f6c1950623f33eae18484e70244fe0af --- python/tvm/relay/op/strategy/arm_cpu.py | 8 +- python/tvm/topi/arm_cpu/depthwise_conv2d.py | 125 +++++++++++++++++- .../topi/python/test_topi_depthwise_conv2d.py | 4 +- 3 files changed, 130 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 8143cc56495a..1c36d39ce4cf 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -167,11 +167,11 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): name="depthwise_conv2d_nchw.x86") elif layout == "NHWC": assert kernel_layout == "HWOI" - logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.") + #logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.") strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), - wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc), - name="depthwise_conv2d_nhwc.generic") + wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc), + wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc), + name="depthwise_conv2d_nhwc.arm_cpu") else: raise RuntimeError("Unsupported depthwise_conv2d layout {} for arm cpu". format(layout)) diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index 802b3df19530..f36525d6669e 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -31,7 +31,6 @@ def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype """Compute depthwise_conv2d with NCHW layout""" return nn.depthwise_conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype) - @autotvm.register_topi_schedule("depthwise_conv2d_nchw.arm_cpu") def schedule_depthwise_conv2d_nchw(cfg, outs): """Schedule depthwise conv2d @@ -181,6 +180,130 @@ def depthwise_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dila return _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2) +@autotvm.register_topi_compute("depthwise_conv2d_nhwc.arm_cpu") +def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, out_dtype): + """TOPI compute callback for depthwise_conv2d nhwc + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + data : tvm.te.Tensor + 4-D with shape [batch, in_height, in_width, in_channel] + + kernel : tvm.te.Tensor + 4-D with shape [filter_height, filter_width, in_channel, channel_multiplier] + + strides : list of two ints + [stride_height, stride_width] + + padding : list of two ints + [pad_height, pad_width] + + dilation : list of two ints + [dilation_height, dilation_width] + + out_dtype: str + The output type. This is used for mixed precision. 
+ + Returns + ------- + output : tvm.te.Tensor + 4-D with shape [batch, out_height, out_width, out_channel] + """ + + out_dtype = out_dtype or data.dtype + + N, IH, IW, IC = get_const_tuple(data.shape) + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + KH, KW, IC, channel_multiplier = get_const_tuple(kernel.shape) + + dilated_kernel_h = (KH - 1) * dilation_h + 1 + dilated_kernel_w = (KW - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) + + OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1 + OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1 + + if pad_top or pad_left: + data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], + name="data_pad") + else: + data_pad = data + + output_shape = (N, OH, OW, IC*channel_multiplier) + + idxdiv = tvm.tir.indexdiv + idxmod = tvm.tir.indexmod + + reduce_h = te.reduce_axis((0, KH), name='reduce_h') + reduce_w = te.reduce_axis((0, KW), name='reduce_w') + + out = te.compute(output_shape, lambda n, h, w, c: + te.sum(data_pad[n, + HSTR*h+dilation_h*reduce_h, + w*WSTR+reduce_w*dilation_w, + idxdiv(c, channel_multiplier)].astype(out_dtype) * + kernel[reduce_h, + reduce_w, + idxdiv(c, channel_multiplier), + idxmod(c, channel_multiplier)].astype(out_dtype), + axis=[reduce_h, reduce_w]), + name='depthwise_conv2d_nhwc_output') + + return out + +@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu") +def schedule_depthwise_conv2d_nhwc(_, outs): + """Create the schedule for depthwise_conv2d_nchw_spatial_pack""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + out = outs[0] + + def schedule_conv(conv): + n, w, h, c = conv.op.axis + r_h, r_w = conv.op.reduce_axis + co, ci = s[conv].split(c, 8) + wo, wi = s[conv].split(w, 2) + ho, hi = s[conv].split(h, 2) + + s[conv].reorder(n, wo, ho, co, wi, hi, r_h, r_w, ci) + s[conv].parallel(wo) + s[conv].vectorize(ci) + + def schedule_conv_out(out): + n, h, w, c = out.op.axis + co, ci = s[out].split(c, 8) + wo, wi = s[out].split(w, 2) + ho, hi = s[out].split(h, 2) + ci_outer, ci_inner = s[out].split(ci, 4) + s[out].reorder(n, wo, ho, co, wi, hi) + s[out].vectorize(ci_inner) + compute_at_axis = hi + s[out].parallel(wo) + return compute_at_axis + + def _callback(op): + if op.name == 'depthwise_conv2d_nhwc_output': + conv = op.output(0) + if conv != out: + compute_at_axis = schedule_conv_out(out) + schedule_conv(conv) + s[conv].compute_at(s[out], compute_at_axis) + else: + schedule_conv(out) + + traverse_inline(s, outs[0].op, _callback) + return s @autotvm.register_topi_schedule("depthwise_conv2d_nchw_spatial_pack.arm_cpu") def schedule_depthwise_conv2d_nchw_spatial_pack(cfg, outs): diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index 397861713f73..d5a68b94aba8 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -40,6 +40,7 @@ _depthwise_conv2d_nhwc_implement = { "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc), + "arm_cpu": (topi.arm_cpu.compute_depthwise_conv2d_nhwc, topi.arm_cpu.schedule_depthwise_conv2d_nhwc), "gpu": (topi.nn.depthwise_conv2d_nhwc, 
topi.cuda.schedule_depthwise_conv2d_nhwc), } @@ -385,8 +386,7 @@ def test_depthwise_conv2d(): depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID") depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID") # dilation = 2 - # disabled because it uses too large shared memory on cuda - # depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2) + depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2) # NCHW[x]c depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME") From 0385a61aa4197d9def91cb82187d7815fe1b410a Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Tue, 21 Jul 2020 10:28:35 +0100 Subject: [PATCH 02/11] Add tuning knobs in depthwise schedule Change-Id: I15080e7f12b16e6c6aba99a04e42023845eeabf1 --- python/tvm/topi/arm_cpu/depthwise_conv2d.py | 61 +++++++++++++++------ 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index f36525d6669e..701ec6ac2e3a 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -20,6 +20,7 @@ import tvm from tvm import te from tvm import autotvm +from tvm.autotvm.task.space import SplitEntity, AnnotateEntity from .. import nn from ..util import traverse_inline, get_const_tuple, get_const_int @@ -259,46 +260,70 @@ def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, o idxmod(c, channel_multiplier)].astype(out_dtype), axis=[reduce_h, reduce_w]), name='depthwise_conv2d_nhwc_output') - return out @autotvm.register_topi_schedule("depthwise_conv2d_nhwc.arm_cpu") -def schedule_depthwise_conv2d_nhwc(_, outs): +def schedule_depthwise_conv2d_nhwc(cfg, outs): """Create the schedule for depthwise_conv2d_nchw_spatial_pack""" outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) out = outs[0] + ##### space definition begin ##### + n, h, w, c = s[out].op.axis + cfg.define_split('tile_c', c, num_outputs=2) + _, hi = cfg.define_split('tile_h', h, num_outputs=2) + _, wi = cfg.define_split('tile_w', w, num_outputs=2) + cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1) + + # fallback support + if cfg.is_fallback: + cfg['tile_c'] = SplitEntity([-1, 8]) + cfg['tile_h'] = SplitEntity([-1, 2]) + cfg['tile_w'] = SplitEntity([-1, 2]) + cfg['locate_output'] = AnnotateEntity([1]) + ##### space definition end ##### + def schedule_conv(conv): + conv_data = conv.op.input_tensors[0] + if conv_data.name == "data_pad": + s[conv_data].compute_inline() + n, w, h, c = conv.op.axis r_h, r_w = conv.op.reduce_axis - co, ci = s[conv].split(c, 8) - wo, wi = s[conv].split(w, 2) - ho, hi = s[conv].split(h, 2) + ho, hi = cfg['tile_h'].apply(s, conv, h) + wo, wi = cfg['tile_w'].apply(s, conv, w) + co, ci = cfg['tile_c'].apply(s, conv, c) - s[conv].reorder(n, wo, ho, co, wi, hi, r_h, r_w, ci) - s[conv].parallel(wo) + s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci) + fused_n_ho = s[conv].fuse(n, ho) + s[conv].parallel(fused_n_ho) s[conv].vectorize(ci) def schedule_conv_out(out): n, h, w, c = out.op.axis - co, ci = s[out].split(c, 8) - wo, wi = s[out].split(w, 2) - ho, hi = s[out].split(h, 2) - ci_outer, ci_inner = s[out].split(ci, 4) - s[out].reorder(n, wo, ho, co, wi, hi) - s[out].vectorize(ci_inner) - compute_at_axis = hi - s[out].parallel(wo) - return compute_at_axis + co, ci = cfg['tile_c'].apply(s, out, c) + wo, wi = cfg['tile_w'].apply(s, out, w) + ho, hi = 
cfg['tile_h'].apply(s, out, h) + + if out.dtype in ['int8', 'uint8']: + # In case of quantized convolution further split the channel in batches of 4 elements + # so that we can use arm intrinsics to run fixed_point_multiplication + ci_outer, ci_inner = s[out].split(ci, 4) + s[out].reorder(n, ho, wo, co, hi, wi) + s[out].vectorize(ci_inner) + + fused_n_ho = s[out].fuse(n, ho) + s[out].parallel(fused_n_ho) + return hi, wi def _callback(op): if op.name == 'depthwise_conv2d_nhwc_output': conv = op.output(0) if conv != out: - compute_at_axis = schedule_conv_out(out) + hi, wi = schedule_conv_out(out) schedule_conv(conv) - s[conv].compute_at(s[out], compute_at_axis) + cfg['locate_output'].apply(s, out, [hi, wi], source=[[conv]]) else: schedule_conv(out) From edc6c1b7f70d1ae578a673ebf8aba494dacd0a6a Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Wed, 22 Jul 2020 14:42:01 +0100 Subject: [PATCH 03/11] Introduce padding policy Change-Id: If12a6d05dce9153861550ddef1ee5216809dd1e1 --- python/tvm/topi/arm_cpu/depthwise_conv2d.py | 31 ++++++++++++++------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index 701ec6ac2e3a..99dcc57b5a07 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -20,7 +20,7 @@ import tvm from tvm import te from tvm import autotvm -from tvm.autotvm.task.space import SplitEntity, AnnotateEntity +from tvm.autotvm.task.space import SplitEntity, AnnotateEntity, OtherOptionEntity from .. import nn from ..util import traverse_inline, get_const_tuple, get_const_int @@ -235,7 +235,7 @@ def compute_depthwise_conv2d_nhwc(_, data, kernel, strides, padding, dilation, o OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1 OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1 - if pad_top or pad_left: + if pad_top or pad_left or pad_down or pad_right: data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0], name="data_pad") else: @@ -286,8 +286,6 @@ def schedule_depthwise_conv2d_nhwc(cfg, outs): def schedule_conv(conv): conv_data = conv.op.input_tensors[0] - if conv_data.name == "data_pad": - s[conv_data].compute_inline() n, w, h, c = conv.op.axis r_h, r_w = conv.op.reduce_axis @@ -295,37 +293,50 @@ def schedule_conv(conv): wo, wi = cfg['tile_w'].apply(s, conv, w) co, ci = cfg['tile_c'].apply(s, conv, c) + if conv_data.name == "data_pad": + # Define a policy for padding computation + cfg.define_knob('data_pad_inline', [1, 2, 3]) + if cfg.is_fallback: + cfg['data_pad_inline'] = OtherOptionEntity(3) + if cfg['data_pad_inline'].val == 1: + s[conv_data].compute_at(s[conv], ho) + if cfg['data_pad_inline'].val == 2: + s[conv_data].compute_at(s[conv], wo) + if cfg['data_pad_inline'].val == 3: + s[conv_data].compute_inline() + s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci) fused_n_ho = s[conv].fuse(n, ho) - s[conv].parallel(fused_n_ho) s[conv].vectorize(ci) + return fused_n_ho def schedule_conv_out(out): n, h, w, c = out.op.axis co, ci = cfg['tile_c'].apply(s, out, c) wo, wi = cfg['tile_w'].apply(s, out, w) ho, hi = cfg['tile_h'].apply(s, out, h) + s[out].reorder(n, ho, wo, co, hi, wi) if out.dtype in ['int8', 'uint8']: # In case of quantized convolution further split the channel in batches of 4 elements # so that we can use arm intrinsics to run fixed_point_multiplication ci_outer, ci_inner = s[out].split(ci, 4) - s[out].reorder(n, ho, wo, co, hi, wi) s[out].vectorize(ci_inner) 
fused_n_ho = s[out].fuse(n, ho) - s[out].parallel(fused_n_ho) - return hi, wi + return hi, wi, fused_n_ho def _callback(op): if op.name == 'depthwise_conv2d_nhwc_output': conv = op.output(0) if conv != out: - hi, wi = schedule_conv_out(out) + hi, wi, p_axis = schedule_conv_out(out) schedule_conv(conv) cfg['locate_output'].apply(s, out, [hi, wi], source=[[conv]]) else: - schedule_conv(out) + p_axis = schedule_conv(out) + + s[out].parallel(p_axis) traverse_inline(s, outs[0].op, _callback) return s From a317d769c6b008eb033c67296f3e9a0aba256848 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Wed, 22 Jul 2020 16:38:38 +0100 Subject: [PATCH 04/11] Vectorize padding Change-Id: I7e2062a40358bf111c0366a449945eb077fb2e30 --- python/tvm/topi/arm_cpu/depthwise_conv2d.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index 99dcc57b5a07..4e3d5b98d83d 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -299,8 +299,10 @@ def schedule_conv(conv): if cfg.is_fallback: cfg['data_pad_inline'] = OtherOptionEntity(3) if cfg['data_pad_inline'].val == 1: + s[conv_data].vectorize(list(s[conv_data].op.axis)[-1]) s[conv_data].compute_at(s[conv], ho) if cfg['data_pad_inline'].val == 2: + s[conv_data].vectorize(list(s[conv_data].op.axis)[-1]) s[conv_data].compute_at(s[conv], wo) if cfg['data_pad_inline'].val == 3: s[conv_data].compute_inline() From 30a1204b7f746b0f26856c300cfd47dbea1fd52b Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 23 Jul 2020 10:24:49 +0100 Subject: [PATCH 05/11] Legalize depthwise convolution (2x improvement) and fix tuning issue Change-Id: I4b82c58b167e40b0b7747d28293bbb488c505dd9 --- python/tvm/relay/qnn/op/legalizations.py | 8 +++++++- python/tvm/topi/arm_cpu/depthwise_conv2d.py | 9 ++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index af5072ef74cd..62bee302984d 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -248,7 +248,13 @@ def is_aarch64_arm(): @qnn_conv2d_legalize.register('arm_cpu') def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types): # ARM prefers the dtypes to be same. 
- if (is_aarch64_arm() and attrs["data_layout"] == "NHWC") or is_fast_int8_on_arm(): + is_depthwise = relay.op.strategy.is_depthwise_conv2d(types[0].shape, + attrs['data_layout'], + types[1].shape, + attrs['kernel_layout'], + attrs['groups']) + use_int8_on_arm = (not is_depthwise) and is_aarch64_arm() and attrs["data_layout"] == "NHWC" + if use_int8_on_arm or is_fast_int8_on_arm(): return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d) return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d) diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index 4e3d5b98d83d..09386ac00beb 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -274,14 +274,14 @@ def schedule_depthwise_conv2d_nhwc(cfg, outs): cfg.define_split('tile_c', c, num_outputs=2) _, hi = cfg.define_split('tile_h', h, num_outputs=2) _, wi = cfg.define_split('tile_w', w, num_outputs=2) - cfg.define_annotate('locate_output', [hi, wi], 'locate_cache', num_anchor=1) + cfg.define_knob('locate_output', [0, 1]) # fallback support if cfg.is_fallback: cfg['tile_c'] = SplitEntity([-1, 8]) cfg['tile_h'] = SplitEntity([-1, 2]) cfg['tile_w'] = SplitEntity([-1, 2]) - cfg['locate_output'] = AnnotateEntity([1]) + cfg['locate_output'] = OtherOptionEntity(1) ##### space definition end ##### def schedule_conv(conv): @@ -334,7 +334,10 @@ def _callback(op): if conv != out: hi, wi, p_axis = schedule_conv_out(out) schedule_conv(conv) - cfg['locate_output'].apply(s, out, [hi, wi], source=[[conv]]) + if cfg['locate_output'].val == 0: + s[conv].compute_at(s[out], hi) + if cfg['locate_output'].val == 1: + s[conv].compute_at(s[out], wi) else: p_axis = schedule_conv(out) From bf976b4b83c559cd2bd6237259c345b82fa9b695 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 23 Jul 2020 14:33:55 +0100 Subject: [PATCH 06/11] Adding assert on padding Change-Id: Idf8eeaaface5eb7799109cd00f437e404778b9cd --- python/tvm/topi/arm_cpu/depthwise_conv2d.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index 09386ac00beb..34da593b4bde 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -294,6 +294,7 @@ def schedule_conv(conv): co, ci = cfg['tile_c'].apply(s, conv, c) if conv_data.name == "data_pad": + assert isinstance(conv_data.op, tvm.te.ComputeOp) # Define a policy for padding computation cfg.define_knob('data_pad_inline', [1, 2, 3]) if cfg.is_fallback: From 69d1f118cf173ec2641cdb642a0119f4c35d1987 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 23 Jul 2020 15:20:58 +0100 Subject: [PATCH 07/11] Fix python linting Change-Id: Iac16a8daea1268f0eb331fe4ec18a62408106cf9 --- python/tvm/topi/arm_cpu/depthwise_conv2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index 34da593b4bde..07749ee72394 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -20,7 +20,7 @@ import tvm from tvm import te from tvm import autotvm -from tvm.autotvm.task.space import SplitEntity, AnnotateEntity, OtherOptionEntity +from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from .. 
import nn from ..util import traverse_inline, get_const_tuple, get_const_int From c53865dc04c8a180a876d5a273ff5ac62f104bbc Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 23 Jul 2020 17:11:29 +0100 Subject: [PATCH 08/11] Removing commented code Change-Id: I1412f22ad9864273d77a7bf38a6768694339b7f0 --- python/tvm/relay/op/strategy/arm_cpu.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 1c36d39ce4cf..0c4edbb410e9 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -167,7 +167,6 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): name="depthwise_conv2d_nchw.x86") elif layout == "NHWC": assert kernel_layout == "HWOI" - #logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.") strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc), wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc), From 849e280f90170db76e6ae1f0f4311e71932411b6 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 23 Jul 2020 22:07:58 +0100 Subject: [PATCH 09/11] Revert test file to make CI pass Change-Id: Ica3eff8f9f0fd4c6f32f7ae80adc922f8b16cec9 --- tests/python/topi/python/test_topi_depthwise_conv2d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index d5a68b94aba8..397861713f73 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -40,7 +40,6 @@ _depthwise_conv2d_nhwc_implement = { "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc), - "arm_cpu": (topi.arm_cpu.compute_depthwise_conv2d_nhwc, topi.arm_cpu.schedule_depthwise_conv2d_nhwc), "gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc), } @@ -386,7 +385,8 @@ def test_depthwise_conv2d(): depthwise_conv2d_with_workload_nhwc(1, 728, 32, 1, 3, 1, "VALID") depthwise_conv2d_with_workload_nhwc(4, 256, 64, 2, 5, 2, "VALID") # dilation = 2 - depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2) + # disabled because it uses too large shared memory on cuda + # depthwise_conv2d_with_workload_nhwc(1, 728, 64, 1, 3, 1, "SAME", dilation=2) # NCHW[x]c depthwise_conv2d_with_workload_NCHWc(1, 728, 32, 1, 3, 1, "SAME") From ec13395c8e223199039dcdca538499034746dd49 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Thu, 23 Jul 2020 23:37:42 +0100 Subject: [PATCH 10/11] Enabling only arm_cpu tests Change-Id: Icbaafcb39e892a5d1a4685133c1699e4d1a8e07e --- tests/python/topi/python/test_topi_depthwise_conv2d.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index 397861713f73..93a166d3e216 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -40,6 +40,7 @@ _depthwise_conv2d_nhwc_implement = { "generic": (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc), + "arm_cpu": (topi.arm_cpu.compute_depthwise_conv2d_nhwc, topi.arm_cpu.schedule_depthwise_conv2d_nhwc), "gpu": (topi.nn.depthwise_conv2d_nhwc, topi.cuda.schedule_depthwise_conv2d_nhwc), } From 6ef6572a364acf85428de65ecd92d8b9eaf84850 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: 
Wed, 12 Aug 2020 12:01:08 +0100 Subject: [PATCH 11/11] Rebasing Change-Id: Ibb23f1d4e0d0107e4e3b3571437161cdc2ee2909
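
Notes: reference sketches for the new NHWC depthwise schedule

The te.compute rule introduced in patch 01 maps output channel c to input
channel idxdiv(c, channel_multiplier) and kernel slice idxmod(c,
channel_multiplier), with output extents computed from the dilated kernel:
for example, IH = 56, KH = 3, dilation_h = 2 gives a dilated extent of
(3 - 1) * 2 + 1 = 5, "SAME" at stride 1 pads 2 + 2, so
OH = (56 + 4 - 5) // 1 + 1 = 56. Below is a plain NumPy restatement of that
rule, useful for eyeballing the semantics. It is a minimal sketch: the
function name is made up, and padding is assumed to be pre-resolved to the
(top, left, down, right) tuple that get_pad_tuple returns.

import numpy as np

def depthwise_conv2d_nhwc_ref(data, kernel, strides, padding, dilation):
    # data: [N, IH, IW, IC]; kernel: [KH, KW, IC, M] (HWOI, as the strategy
    # asserts); strides/dilation: (h, w) tuples; padding: (top, left, down,
    # right), already resolved.
    HSTR, WSTR = strides
    dil_h, dil_w = dilation
    pad_top, pad_left, pad_down, pad_right = padding
    N, IH, IW, IC = data.shape
    KH, KW, _, M = kernel.shape
    dk_h = (KH - 1) * dil_h + 1                  # dilated kernel extents
    dk_w = (KW - 1) * dil_w + 1
    OH = (IH + pad_top + pad_down - dk_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dk_w) // WSTR + 1
    data_pad = np.pad(data, ((0, 0), (pad_top, pad_down),
                             (pad_left, pad_right), (0, 0)))
    out = np.zeros((N, OH, OW, IC * M), dtype=data.dtype)
    for c in range(IC * M):
        ic, m = c // M, c % M                    # idxdiv / idxmod in the patch
        for rh in range(KH):
            for rw in range(KW):
                # strided, dilated input window for this kernel tap
                win = data_pad[:,
                               rh * dil_h : rh * dil_h + OH * HSTR : HSTR,
                               rw * dil_w : rw * dil_w + OW * WSTR : WSTR,
                               ic]
                out[:, :, :, c] += win * kernel[rh, rw, ic, m]
    return out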
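
The tuning knobs from patch 02 are easier to see on a free-standing schedule.
The sketch below reproduces the fallback configuration by hand (tile_c = 8,
tile_h = tile_w = 2, padding inlined as with data_pad_inline == 3, n fused
with the outer h loop for parallelism, inner channels vectorized) on the
generic topi.nn compute; the shapes are invented and no autotvm config is
involved.

import tvm
from tvm import te, topi

data = te.placeholder((1, 56, 56, 64), name="data")    # assumed shapes
kernel = te.placeholder((3, 3, 64, 1), name="kernel")
conv = topi.nn.depthwise_conv2d_nhwc(data, kernel, (1, 1), "SAME", (1, 1))

s = te.create_schedule(conv.op)
s[conv.op.input_tensors[0]].compute_inline()   # pad stage; knob value 3
n, h, w, c = conv.op.axis
r_h, r_w = conv.op.reduce_axis
ho, hi = s[conv].split(h, 2)                   # cfg['tile_h'] fallback
wo, wi = s[conv].split(w, 2)                   # cfg['tile_w'] fallback
co, ci = s[conv].split(c, 8)                   # cfg['tile_c'] fallback
s[conv].reorder(n, ho, wo, co, hi, wi, r_h, r_w, ci)
fused_n_ho = s[conv].fuse(n, ho)
s[conv].parallel(fused_n_ho)
s[conv].vectorize(ci)
print(tvm.lower(s, [data, kernel, conv], simple_mode=True))

When a fused elementwise op follows the convolution, the locate_output knob
(patches 02/05) additionally picks whether the conv stage is computed at hi
or wi of that output stage; with nothing fused, conv is itself the output and
is scheduled directly, as in this sketch.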
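
Wiring the two sketches together gives a quick numerical check (it reuses
np, tvm, s, data, kernel, conv and depthwise_conv2d_nhwc_ref from above, and
assumes a local llvm target is available):

f = tvm.build(s, [data, kernel, conv], target="llvm")
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(1, 56, 56, 64)).astype("float32"), ctx)
wgt = tvm.nd.array(np.random.uniform(size=(3, 3, 64, 1)).astype("float32"), ctx)
res = tvm.nd.array(np.zeros((1, 56, 56, 64), dtype="float32"), ctx)
f(a, wgt, res)
# 3x3 "SAME" at stride 1 resolves to (top, left, down, right) = (1, 1, 1, 1)
ref = depthwise_conv2d_nhwc_ref(a.asnumpy(), wgt.asnumpy(),
                                (1, 1), (1, 1, 1, 1), (1, 1))
np.testing.assert_allclose(res.asnumpy(), ref, rtol=1e-5)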