diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 005eae68b8b7..2d331d0b57c6 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -21,6 +21,7 @@
 
 from tvm import topi
 from ....target import arm_isa
+from ....topi.generic import conv2d as conv2d_generic
 from .generic import *
 from .. import op as _op
 
@@ -197,11 +198,19 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
             )
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
-                wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.arm_cpu",
-            )
+            is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
+            if is_aarch64 or "+neon" in target.mattr:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
+                    wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
+                    name="depthwise_conv2d_nhwc.arm_cpu",
+                )
+            else:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
+                    wrap_topi_schedule(conv2d_generic.schedule_depthwise_conv2d_nhwc),
+                    name="depthwise_conv2d_nhwc.generic",
+                )
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {} for arm cpu".format(layout))
     else:  # group_conv2d
diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py
index 4daa84c29528..3772fdbafe6c 100644
--- a/python/tvm/topi/generic/conv2d.py
+++ b/python/tvm/topi/generic/conv2d.py
@@ -20,7 +20,7 @@
 from tvm import te
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
-from ..utils import get_const_tuple
+from ..utils import get_const_tuple, traverse_inline
 
 
 def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
@@ -361,3 +361,32 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
             raise ValueError("Unsupported output ndim: %s" % out_ndim)
 
     return s
+
+
+def schedule_depthwise_conv2d_nhwc(outs):
+    """Create schedule for depthwise conv2d in NHWC layout.
+    Parameters
+    ----------
+    outs : list[te.tensor.Tensor]
+        The output tensors.
+    Returns
+    -------
+    s : tvm.te.schedule.Schedule
+        The computation schedule for depthwise conv2d.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        """Traverse operators from computation graph"""
+        if "depthwise_conv2d_nhwc" in op.tag:
+            out = outs[0]
+            depthwise_conv2d_out = op.output(0)
+            data_pad = depthwise_conv2d_out.op.input_tensors[0]
+            s[data_pad].compute_inline()
+            if depthwise_conv2d_out != out:
+                s[depthwise_conv2d_out].compute_at(s[out], s[out].op.axis[3])
+            s[out].fuse(*s[out].op.axis)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py
index 27601cd32b89..24c232129c91 100644
--- a/tests/python/topi/python/test_topi_depthwise_conv2d.py
+++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py
@@ -30,6 +30,7 @@
 from tvm.contrib.pickle_memoize import memoize
 from tvm.topi.nn.depthwise_conv2d import _get_workload
 from tvm.topi.x86.depthwise_conv2d import _fallback_schedule
+from tvm.topi.generic import conv2d as conv2d_generic
 
 
 _depthwise_conv2d_implement = {
@@ -53,7 +54,10 @@
         ],
     },
     "NHWC": {
-        "generic": [(topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc)],
+        "generic": [
+            (topi.nn.depthwise_conv2d_nhwc, topi.generic.schedule_depthwise_conv2d_nhwc),
+            (topi.nn.depthwise_conv2d_nhwc, conv2d_generic.schedule_depthwise_conv2d_nhwc),
+        ],
         "arm_cpu": [
             (
                 topi.arm_cpu.compute_depthwise_conv2d_nhwc,