diff --git a/python/tvm/relay/op/strategy/mali.py b/python/tvm/relay/op/strategy/mali.py index 6c6440e486f1..d38fe0d82758 100644 --- a/python/tvm/relay/op/strategy/mali.py +++ b/python/tvm/relay/op/strategy/mali.py @@ -73,36 +73,39 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target): elif layout == "NHWC": assert kernel_layout == "HWIO" if not is_auto_scheduler_enabled(): - raise RuntimeError( - "conv2d NHWC layout is not enabled for mali without auto_scheduler." - ) - strategy.add_implementation( - wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), - naive_schedule, - name="conv2d_nhwc.mali", - ) - is_winograd_applicable = False - if len(kernel.shape) == 4: - kernel_h, kernel_w, _, _ = get_const_tuple(kernel.shape) - is_winograd_applicable = ( - "float" in data.dtype - and "float" in kernel.dtype - and kernel_h == 3 - and kernel_w == 3 - and stride_h == 1 - and stride_w == 1 - and dilation_h == 1 - and dilation_w == 1 + strategy.add_implementation( + wrap_compute_conv2d(topi.mali.conv2d_nhwc_spatial_pack), + wrap_topi_schedule(topi.mali.schedule_conv2d_nhwc_spatial_pack), + name="conv2d_nhwc_spatial_pack.mali", ) - if is_winograd_applicable: + else: strategy.add_implementation( - wrap_compute_conv2d( - topi.nn.conv2d_winograd_nhwc, need_auto_scheduler_layout=True - ), - naive_schedule, # this implementation should never be picked by autotvm - name="conv2d_nhwc.winograd", - plevel=15, + wrap_compute_conv2d(topi.nn.conv2d_nhwc, need_auto_scheduler_layout=True), + naive_schedule, + name="conv2d_nhwc.mali", ) + is_winograd_applicable = False + if len(kernel.shape) == 4: + kernel_h, kernel_w, _, _ = get_const_tuple(kernel.shape) + is_winograd_applicable = ( + "float" in data.dtype + and "float" in kernel.dtype + and kernel_h == 3 + and kernel_w == 3 + and stride_h == 1 + and stride_w == 1 + and dilation_h == 1 + and dilation_w == 1 + ) + if is_winograd_applicable: + strategy.add_implementation( + wrap_compute_conv2d( + topi.nn.conv2d_winograd_nhwc, need_auto_scheduler_layout=True + ), + naive_schedule, # this implementation should never be picked by autotvm + name="conv2d_nhwc.winograd", + plevel=15, + ) else: raise RuntimeError("Unsupported conv2d layout {} for mali".format(layout)) diff --git a/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py b/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py index f4cd9d899b73..b4076d607a82 100644 --- a/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py +++ b/python/tvm/topi/arm_cpu/conv2d_spatial_pack.py @@ -247,7 +247,7 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, conv, output return s -def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype): +def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=2): """Spatial pack compute for Conv2d NHWC""" out_dtype = out_dtype or data.dtype @@ -276,9 +276,16 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_ n, oc, oh, ow = cfg.axis(N), cfg.axis(OC), cfg.axis(OH), cfg.axis(OW) ic, kh, kw = cfg.reduce_axis(IC), cfg.reduce_axis(KH), cfg.reduce_axis(KW) - oco, oci = cfg.define_split("tile_co", oc, num_outputs=2) - oho, ohi = cfg.define_split("tile_oh", oh, num_outputs=2) - owo, owi = cfg.define_split("tile_ow", ow, num_outputs=2) + if num_tile == 2: # for arm cpu + oco, oci = cfg.define_split("tile_co", oc, num_outputs=2) + oho, ohi = cfg.define_split("tile_oh", oh, num_outputs=2) + owo, owi = cfg.define_split("tile_ow", ow, num_outputs=2) + elif num_tile == 3: # for mali gpu + oco, _, oci = cfg.define_split("tile_co", oc, num_outputs=3) + oho, _, ohi = cfg.define_split("tile_oh", oh, num_outputs=3) + owo, _, owi = cfg.define_split("tile_ow", ow, num_outputs=3) + else: + raise RuntimeError("Invalid num_tile") cfg.define_reorder( "reorder_conv", diff --git a/python/tvm/topi/mali/conv2d.py b/python/tvm/topi/mali/conv2d.py index 52fe011a70e9..f3ef55b9a30c 100644 --- a/python/tvm/topi/mali/conv2d.py +++ b/python/tvm/topi/mali/conv2d.py @@ -30,6 +30,7 @@ # reuse some compute declarations from ARM CPU from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nchw +from ..arm_cpu.conv2d_spatial_pack import conv2d_spatial_pack_nhwc logger = logging.getLogger("topi") @@ -95,37 +96,59 @@ def schedule_conv2d_nchw_spatial_pack(cfg, outs): def _callback(op): # schedule conv2d if "spatial_conv2d_output" in op.tag: - output = op.output(0) - conv = op.input_tensors[0] + _schedule_spatial_pack(cfg, s, op, layout="NCHW") + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_spatial_pack.mali") +def conv2d_nhwc_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d with NHWC layout""" + return conv2d_spatial_pack_nhwc( + cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile=3 + ) - data_vec = conv.op.input_tensors[0] - data_pad = data_vec.op.input_tensors[0] - s[data_pad].compute_inline() - kernel_vec = conv.op.input_tensors[1] - if kernel_vec.op.name == "kernel_vec": - kernel = kernel_vec.op.input_tensors[0] - else: - kernel = kernel_vec - if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: - s[kernel].compute_inline() +@autotvm.register_topi_schedule("conv2d_nhwc_spatial_pack.mali") +def schedule_conv2d_nhwc_spatial_pack(cfg, outs): + """Create schedule for conv2d_nhwc""" + s = te.create_schedule([x.op for x in outs]) - _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec) + def _callback(op): + # schedule conv2d + if "spatial_conv_output_NHWC" in op.tag: + _schedule_spatial_pack(cfg, s, op, layout="NHWC") traverse_inline(s, outs[0].op, _callback) return s -def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): +def _schedule_spatial_pack(cfg, s, op, layout): """schedule the spatial packing for conv2d""" + + assert layout in ("NCHW", "NHWC") + + output = op.output(0) + conv = op.input_tensors[0] + data_vec = conv.op.input_tensors[0] + data_pad = data_vec.op.input_tensors[0] + s[data_pad].compute_inline() + kernel_vec = conv.op.input_tensors[1] + if kernel_vec.op.name == "kernel_vec": + kernel = kernel_vec.op.input_tensors[0] + else: + kernel = kernel_vec + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() data = s[data_vec].op.input_tensors[0] max_unroll = 16 vec_size = [1, 2, 4, 8, 16] # get tunable parameters (they are defined in compute) - BC, TC, VC = cfg["tile_co"].size - BH, TH, VH = cfg["tile_oh"].size - BW, TW, VW = cfg["tile_ow"].size + _, TC, VC = cfg["tile_co"].size + _, TH, VH = cfg["tile_oh"].size + _, TW, VW = cfg["tile_ow"].size # schedule padding if isinstance(data.op, tvm.te.ComputeOp) and "pad" in data.op.tag: @@ -133,21 +156,29 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): s[data_pad].compute_inline() # schedule data packing - if isinstance(data_vec.op, tvm.te.ComputeOp) and data_vec.op.name == "data_vec_undilated": - _, h, w, ci, _, _, vh, vw = s[data_vec].op.axis + if layout == "NCHW": + if isinstance(data_vec.op, tvm.te.ComputeOp) and data_vec.op.name == "data_vec_undilated": + _, h, w, ci, _, _, vh, vw = s[data_vec].op.axis + else: + _, h, w, ci, vh, vw = s[data_vec].op.axis + z, y, x, unroll1, unroll2 = h, w, ci, vh, vw else: - _, h, w, ci, vh, vw = s[data_vec].op.axis - tile_and_bind3d(s, data_vec, h, w, ci, 1) - if vh.dom.extent.value < max_unroll: - s[data_vec].unroll(vh) - if vw.dom.extent.value < max_unroll: - s[data_vec].unroll(vw) + if isinstance(data_vec.op, tvm.te.ComputeOp) and data_vec.op.name == "data_vec_undilated": + _, oho, owo, _, _, ic, ohi, owi = s[data_vec].op.axis + else: + _, oho, owo, ohi, owi, ic = s[data_vec].op.axis + z, y, x, unroll1, unroll2 = oho, owo, ohi, ic, owi + tile_and_bind3d(s, data_vec, z, y, x, 1) + if unroll1.dom.extent.value < max_unroll: + s[data_vec].unroll(unroll1) + if unroll2.dom.extent.value < max_unroll: + s[data_vec].unroll(unroll2) if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == "kernel_vec": if not autotvm.GLOBAL_SCOPE.in_tuning: max_threads = tvm.target.Target.current(allow_none=False).max_num_threads - co, ci, kh, kw, vc = s[kernel_vec].op.axis - fused = s[kernel_vec].fuse(co, ci, kh, kw, vc) + ax1, ax2, ax3, ax4, ax5 = s[kernel_vec].op.axis + fused = s[kernel_vec].fuse(ax1, ax2, ax3, ax4, ax5) fused, vec = s[kernel_vec].split(fused, VC) bb, tt = s[kernel_vec].split(fused, max_threads) s[kernel_vec].bind(bb, te.thread_axis("blockIdx.x")) @@ -156,25 +187,37 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): s[kernel_vec].vectorize(vec) # schedule convolution - n, c, h, w, vh, vw, vc = s[conv].op.axis - kc, kh, kw = s[conv].op.reduce_axis - - cfg["reorder_0"].apply(s, conv, [n, c, h, w, kc, kh, kw, vh, vw, vc]) - tile_and_bind3d(s, conv, c, h, w, TC, TH, TW) - + ic, kh, kw = s[conv].op.reduce_axis + if layout == "NCHW": + kh_dim, kw_dim = kernel_vec.shape[2], kernel_vec.shape[3] + else: + kh_dim, kw_dim = kernel_vec.shape[0], kernel_vec.shape[1] cfg["ann_reduce"].apply( s, conv, [kh, kw], - axis_lens=[get_const_int(kernel_vec.shape[2]), get_const_int(kernel_vec.shape[3])], + axis_lens=[get_const_int(kh_dim), get_const_int(kw_dim)], max_unroll=max_unroll, ) + if layout == "NCHW": + n, c, h, w, vh, vw, vc = s[conv].op.axis + cfg["reorder_0"].apply(s, conv, [n, c, h, w, ic, kh, kw, vh, vw, vc]) + tile_and_bind3d(s, conv, c, h, w, TC, TH, TW) + unroll_vec_axes = [vh, vw, vc] + axis_lens = [VH, VW, VC] + else: + n, oho, owo, oco, ohi, owi, oci = s[conv].op.axis + cfg["reorder_conv"].apply(s, conv, [n, oho, owo, oco, kh, kw, ic, ohi, owi, oci]) + tile_and_bind3d(s, conv, oho, owo, oco, TH, TW, TC) + unroll_vec_axes = [ohi, owi, oci] + axis_lens = [VH, VW, VC] + cfg["ann_spatial"].apply( s, conv, - [vh, vw, vc], - axis_lens=[VH, VW, VC], + unroll_vec_axes, + axis_lens, max_unroll=max_unroll, vec_size=vec_size, cfg=cfg, @@ -184,9 +227,12 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): if output.op not in s.outputs: # has bias s[output].compute_inline() output = s.outputs[0] - - _, co, oh, ow = s[output].op.axis - tile_and_bind3d(s, output, co, oh, ow, TC, TH, TW) + if layout == "NCHW": + _, co, oh, ow = s[output].op.axis + tile_and_bind3d(s, output, co, oh, ow, TC, TH, TW) + else: + _, oh, ow, co = s[output].op.axis + tile_and_bind3d(s, output, oh, ow, co, TH, TW, TC) return s diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc.py b/tests/python/topi/python/test_topi_conv2d_nhwc.py index cdb7c0e8d4aa..a34f4d722cbd 100644 --- a/tests/python/topi/python/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/python/test_topi_conv2d_nhwc.py @@ -34,6 +34,10 @@ topi.arm_cpu.conv2d_nhwc_spatial_pack, topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack, ), + "mali": ( + topi.mali.conv2d_nhwc_spatial_pack, + topi.mali.schedule_conv2d_nhwc_spatial_pack, + ), "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc), }