From d0b4ffade20a64b08631a2f8de439f0a374048e1 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Tue, 2 Feb 2021 16:04:19 -0800 Subject: [PATCH 01/15] Add Adreno device to TVM along with a conv2d schedule utilizing NCHW4c layout and texture memory. --- python/tvm/relay/op/strategy/__init__.py | 1 + python/tvm/relay/op/strategy/adreno.py | 47 +++++ python/tvm/target/target.py | 13 ++ python/tvm/topi/__init__.py | 1 + python/tvm/topi/adreno/__init__.py | 20 ++ python/tvm/topi/adreno/conv2d.py | 255 +++++++++++++++++++++++ 6 files changed, 337 insertions(+) create mode 100644 python/tvm/relay/op/strategy/adreno.py create mode 100644 python/tvm/topi/adreno/__init__.py create mode 100644 python/tvm/topi/adreno/conv2d.py diff --git a/python/tvm/relay/op/strategy/__init__.py b/python/tvm/relay/op/strategy/__init__.py index 8d0543ba30af..699addd827a4 100644 --- a/python/tvm/relay/op/strategy/__init__.py +++ b/python/tvm/relay/op/strategy/__init__.py @@ -21,6 +21,7 @@ from .generic import * from . import x86 +from . import adreno from . import arm_cpu from . import cuda from . import hls diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py new file mode 100644 index 000000000000..f7bdc310726d --- /dev/null +++ b/python/tvm/relay/op/strategy/adreno.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Definition of adreno operator strategy.""" +# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import +import re +from tvm import topi +from .generic import * +from .. 
import op as _op + +@conv2d_strategy.register("adreno") +def conv2d_strategy_adreno(attrs, inputs, out_type, target): + """conv2d adreno strategy""" + strategy = _op.OpStrategy() + data, kernel = inputs + dilation_h, dilation_w = attrs.get_int_tuple("dilation") + stride_h, stride_w = attrs.get_int_tuple("strides") + groups = attrs.groups + layout = attrs.data_layout + kernel_layout = attrs.kernel_layout + if dilation_h < 1 or dilation_w < 1: + raise ValueError("dilation should be positive value") + + if groups == 1: + if layout == "NCHW4c" and kernel_layout == "OIHW4o": + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nchwc), + wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc), + name="conv2d_nchwc.opencl", + ) + else: + raise RuntimeError("group_conv2d is not yet supported for adreno") + return strategy + diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 106432cd44f7..d67c0898627b 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -263,6 +263,19 @@ def mali(model="unknown", options=None): opts = _merge_opts(opts, options) return Target(" ".join(["opencl"] + opts)) +def adreno(model="unknown", options=None): + """Returns a Qualcomm GPU target. + + Parameters + ---------- + model: str + The model of this device + options : str or list of str + Additional options + """ + opts = ["-device=adreno", "-model=%s" % model] + opts = _merge_opts(opts, options) + return Target(" ".join(["opencl"] + opts)) def intel_graphics(model="unknown", options=None): """Returns an Intel Graphics target. diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 9b843ae181fb..81c02067dd3f 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -50,6 +50,7 @@ from . import x86 from . import cuda from . import gpu +from . import adreno from . import arm_cpu from . import mali from . import bifrost diff --git a/python/tvm/topi/adreno/__init__.py b/python/tvm/topi/adreno/__init__.py new file mode 100644 index 000000000000..6217638c4922 --- /dev/null +++ b/python/tvm/topi/adreno/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=redefined-builtin, wildcard-import +"""Qualcomm Adreno GPU specific declaration and schedules.""" +from .conv2d import * diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py new file mode 100644 index 000000000000..28e903e18391 --- /dev/null +++ b/python/tvm/topi/adreno/conv2d.py @@ -0,0 +1,255 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return +"""conv2d schedule on Qualcomm Adreno GPU""" +import tvm +from tvm import te +from tvm import autotvm + +from tvm.topi import nn +from tvm.topi.utils import simplify +from ..utils import get_const_tuple, traverse_inline + + +@autotvm.register_topi_compute("conv2d_nchwc.opencl") +def conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"): + """Compute conv2d with NCHWc layout""" + args={"memory" : "texture", "shared" : False} + return compute_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) + +@autotvm.register_topi_schedule("conv2d_nchwc.opencl") +def schedule_conv2d_nchwc(cfg, outs): + """Create the schedule for conv2d_nchw""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "conv2d_nchwc": + args={"memory" : "texture", "shared" : False} + schedule_conv2d_NCHWc_KCRSk(cfg, s, op.output(0), args) + + traverse_inline(s, outs[0].op, _callback) + return s + + +def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None, args={}): + """Convolution operator in NCHWc layout. 
""" + + if out_dtype is None: + out_dtype = Input.dtype + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_channel_chunk, in_height, in_width, in_channel_block = Input.shape + num_filter_chunk, channel, kernel_h, kernel_w, num_filter_block = Filter.shape + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + # compute graph + pad_before = [0, 0, pad_top, pad_left, 0] + pad_after = [0, 0, pad_down, pad_right, 0] + temp = nn.pad(Input, pad_before, pad_after, name="pad_temp") + + rcc = te.reduce_axis((0, in_channel_chunk), name="rc") + rcb = te.reduce_axis((0, in_channel_block), name="rc") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + + if args["memory"] != None: + # NCHWc x KCRSk + # texture: NCH|W|c + # texture: K|CRS|k + # c = crs//RS + # rs = crs % RS + # r = rs // W == (crs // S) % R + # s = rs % W == crs % S + Filter_tx = te.compute( + (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), + lambda ffc, crs, ffb: Filter[ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb], + name = "packed_filter" + ) + return te.compute( + (batch, num_filter_chunk, out_height, out_width, num_filter_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb].astype( + out_dtype + ) + * Filter_tx[ffc, ((rcc * in_channel_block + rcb)*kernel_h + ry)*kernel_w + rx, ffb].astype(out_dtype), + axis=[rcc, rcb, ry, rx], + ), + tag="conv2d_nchwc", + ) + else: + return te.compute( + (batch, num_filter_chunk, out_height, out_width, num_filter_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb].astype( + out_dtype + ) + * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb].astype(out_dtype), + axis=[rcc, rcb, ry, rx], + ), + tag="conv2d_nchwc", + ) + + +def schedule_conv2d_NCHWc_KCRSk(cfg, s, conv, args={}): + """schedule optimized for batch size = 1""" + ##### space definition begin ##### + n, fc, y, x, fb = s[conv].op.axis + rcc, rcb, ry, rx = s[conv].op.reduce_axis + cfg.define_split("tile_fc", fc, num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rcc", rcc, num_outputs=2) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + target = tvm.target.Target.current() + if target.kind.name in ["nvptx", "rocm"]: + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + ##### space definition end ##### + + if args["memory"] != None: + pad_data, flattened_kernel = s[conv].op.input_tensors + kernel = s[flattened_kernel].op.input_tensors[0] + s[flattened_kernel].compute_inline() + else: + 
pad_data, kernel = s[conv].op.input_tensors + flattened_kernel = kernel + + s[pad_data].compute_inline() + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + kernel = flattened_kernel + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, "local") + else: + output = s.outputs[0].output(0) + s[conv].set_scope("local") + OL = conv + + # create cache stage + if args["memory"] != None: + AT = s.cache_read(pad_data, args["memory"], [OL]) + WT = s.cache_read(kernel, args["memory"], [OL]) + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + copy_to_texture(AT) + copy_to_texture(WT) + + if args["shared"]: + AA = s.cache_read(AT, "shared", [OL]) + WW = s.cache_read(WT, "shared", [OL]) + else: + AA = s.cache_read(pad_data, "shared", [OL]) + WW = s.cache_read(kernel, "shared", [OL]) + + # tile and bind spatial axes + n, fc, y, x, fb = s[output].op.axis + + kernel_scope, n = s[output].split(n, nparts=1) + + bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + bf = s[output].fuse(n, bf) + s[output].bind(bf, te.thread_axis("blockIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(vf, te.thread_axis("vthread")) + s[output].bind(vy, te.thread_axis("vthread")) + s[output].bind(vx, te.thread_axis("vthread")) + s[output].bind(tf, te.thread_axis("threadIdx.z")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) + s[output].vectorize(fb) + s[OL].compute_at(s[output], tx) + + # tile reduction axes + n, fc, y, x, fb = s[OL].op.axis + + rcc, rcb, ry, rx = s[OL].op.reduce_axis + rco, rci = cfg["tile_rcc"].apply(s, OL, rcc) + ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) + rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) + + # TODO(csullivan): check position of rcb + s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb) + s[OL].vectorize(fb) + s[OL].unroll(rcb) + + if args["memory"] == None or args["shared"]: + s[AA].compute_at(s[OL], rxo) + s[WW].compute_at(s[OL], rxo) + # cooperative fetching + for load in [AA, WW]: + if args["memory"] != None and load == WW: + n, fyx, v = s[load].op.axis + fused = s[load].fuse(n, fyx) + else: + n, f, y, x, v = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) + ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) + tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + s[load].vectorize(v) + + # unroll + s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) + + N, OCC, OH, OW, OCB = get_const_tuple(output.shape) + if args["memory"] != None: + _, ICKHKW, _ = get_const_tuple(kernel.shape) + else: + _, IC, KH, KW, _ = get_const_tuple(kernel.shape) + ICKHKW = IC*KH*KW + + + if isinstance(N, int): + cfg.add_flop(2 * N * OH * OW * OCC * 
OCB * ICKHKW) From 606fe8f91a39a1b70462dc20e4fecf36d0fc5ca4 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Fri, 19 Feb 2021 15:31:45 -0800 Subject: [PATCH 02/15] Add float32 accumulator strategy and update topi impl. for conv2d. --- python/tvm/relay/op/strategy/adreno.py | 20 +++++++++-- python/tvm/topi/adreno/conv2d.py | 47 +++++++++++++++----------- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index f7bdc310726d..d455dd927565 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -29,18 +29,34 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): dilation_h, dilation_w = attrs.get_int_tuple("dilation") stride_h, stride_w = attrs.get_int_tuple("strides") groups = attrs.groups - layout = attrs.data_layout + data_layout = attrs.data_layout kernel_layout = attrs.kernel_layout + assert out_type.dtype == "float16", "No float32 input/output tensor support is currently provided for Adreno GPU" if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") if groups == 1: - if layout == "NCHW4c" and kernel_layout == "OIHW4o": + if data_layout == "NCHW" and kernel_layout == "OIHW": + strategy.add_implementation( + wrap_compute_conv2d(topi.mali.conv2d_nchw_spatial_pack), + wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_spatial_pack), + name="conv2d_nchw_spatial_pack.mali", + ) + elif data_layout == "NCHW4c" and kernel_layout == "OIHW4o": strategy.add_implementation( wrap_compute_conv2d(topi.adreno.conv2d_nchwc), wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc), name="conv2d_nchwc.opencl", + plevel=10 + ) + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nchwc_acc32), + wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc_acc32), + name="conv2d_nchwc_acc32.opencl", + plevel=20 ) + else: + raise RuntimeError("Layout not supported: ("+data_layout+", "+kernel_layout+") - only support NCHW4c / OIHW4o layouts for conv2d") else: raise RuntimeError("group_conv2d is not yet supported for adreno") return strategy diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py index 28e903e18391..caab3686484b 100644 --- a/python/tvm/topi/adreno/conv2d.py +++ b/python/tvm/topi/adreno/conv2d.py @@ -26,19 +26,31 @@ @autotvm.register_topi_compute("conv2d_nchwc.opencl") -def conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"): +def conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): """Compute conv2d with NCHWc layout""" - args={"memory" : "texture", "shared" : False} + args={"memory" : "texture", "shared" : False, "accumulator" : "float16"} + return compute_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) + +@autotvm.register_topi_compute("conv2d_nchwc_acc32.opencl") +def conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): + """Compute conv2d with NCHWc layout""" + args={"memory" : "texture", "shared" : False, "accumulator" : "float32"} return compute_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) @autotvm.register_topi_schedule("conv2d_nchwc.opencl") def schedule_conv2d_nchwc(cfg, outs): + return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc16") + +@autotvm.register_topi_schedule("conv2d_nchwc_acc32.opencl") +def schedule_conv2d_nchwc_acc32(cfg, outs): + return schedule_conv2d_nchwc_impl(cfg, outs, 
tag="cast_from_acc32") + +def schedule_conv2d_nchwc_impl(cfg, outs, tag): """Create the schedule for conv2d_nchw""" outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) - def _callback(op): - if op.tag == "conv2d_nchwc": + if op.tag == tag: args={"memory" : "texture", "shared" : False} schedule_conv2d_NCHWc_KCRSk(cfg, s, op.output(0), args) @@ -88,42 +100,36 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dty # NCHWc x KCRSk # texture: NCH|W|c # texture: K|CRS|k - # c = crs//RS - # rs = crs % RS - # r = rs // W == (crs // S) % R - # s = rs % W == crs % S Filter_tx = te.compute( (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), lambda ffc, crs, ffb: Filter[ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb], name = "packed_filter" ) - return te.compute( + conv = te.compute( (batch, num_filter_chunk, out_height, out_width, num_filter_block), lambda nn, ffc, yy, xx, ffb: te.sum( - temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb].astype( - out_dtype - ) - * Filter_tx[ffc, ((rcc * in_channel_block + rcb)*kernel_h + ry)*kernel_w + rx, ffb].astype(out_dtype), + (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] + * Filter_tx[ffc, ((rcc * in_channel_block + rcb)*kernel_h + ry)*kernel_w + rx, ffb]).astype(args["accumulator"]), axis=[rcc, rcb, ry, rx], ), tag="conv2d_nchwc", ) else: - return te.compute( + conv = te.compute( (batch, num_filter_chunk, out_height, out_width, num_filter_block), lambda nn, ffc, yy, xx, ffb: te.sum( - temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb].astype( - out_dtype - ) - * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb].astype(out_dtype), + (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] + * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb]).astype(args["accumulator"]), axis=[rcc, rcb, ry, rx], ), tag="conv2d_nchwc", ) + return te.compute(conv.shape, lambda n,fc,y,x,fb: conv[n,fc,y,x,fb].astype("float16"), tag="cast_from_acc" + args["accumulator"][-2:]) - -def schedule_conv2d_NCHWc_KCRSk(cfg, s, conv, args={}): +def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): """schedule optimized for batch size = 1""" + conv = output.op.input_tensors[0] + ##### space definition begin ##### n, fc, y, x, fb = s[conv].op.axis rcc, rcb, ry, rx = s[conv].op.reduce_axis @@ -205,6 +211,7 @@ def copy_to_texture(stage): s[output].bind(tx, te.thread_axis("threadIdx.x")) s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) s[output].vectorize(fb) + s[OL].compute_at(s[output], tx) # tile reduction axes From d12efb92067e45be4e8ebc7c45b3cf7a26badf5f Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Mon, 22 Feb 2021 15:46:37 -0800 Subject: [PATCH 03/15] Add depthwise conv2d impl. 
--- python/tvm/relay/op/strategy/adreno.py | 16 +- python/tvm/topi/adreno/conv2d.py | 239 ++++++++++++++++++++++++- 2 files changed, 253 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index d455dd927565..b7b59c0b16a3 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -58,6 +58,20 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): else: raise RuntimeError("Layout not supported: ("+data_layout+", "+kernel_layout+") - only support NCHW4c / OIHW4o layouts for conv2d") else: - raise RuntimeError("group_conv2d is not yet supported for adreno") + if data_layout == "NCHW4c" and kernel_layout == "OIHW4o": + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc), + wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc), + name="depthwise_conv2d_nchwc.opencl", + plevel=10 + ) + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc_acc32), + wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc_acc32), + name="depthwise_conv2d_nchwc_acc32.opencl", + plevel=20 + ) + else: + raise RuntimeError("Layout not supported: ("+data_layout+", "+kernel_layout+") - only support NCHW4c / OIHW4o layouts for conv2d") return strategy diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py index caab3686484b..aeb407ccd948 100644 --- a/python/tvm/topi/adreno/conv2d.py +++ b/python/tvm/topi/adreno/conv2d.py @@ -45,6 +45,26 @@ def schedule_conv2d_nchwc(cfg, outs): def schedule_conv2d_nchwc_acc32(cfg, outs): return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc32") +@autotvm.register_topi_compute("depthwise_conv2d_nchwc.opencl") +def depthwise_conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): + """Compute depthwise_conv2d with NCHWc layout""" + args={"memory" : "texture", "shared" : False, "accumulator" : "float16"} + return compute_depthwise_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) + +@autotvm.register_topi_compute("depthwise_conv2d_nchwc_acc32.opencl") +def depthwise_conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): + """Compute depthwise_conv2d with NCHWc layout""" + args={"memory" : "texture", "shared" : False, "accumulator" : "float32"} + return compute_depthwise_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) + +@autotvm.register_topi_schedule("depthwise_conv2d_nchwc.opencl") +def schedule_depthwise_conv2d_nchwc(cfg, outs): + return schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc16") + +@autotvm.register_topi_schedule("depthwise_conv2d_nchwc_acc32.opencl") +def schedule_depthwise_conv2d_nchwc_acc32(cfg, outs): + return schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc32") + def schedule_conv2d_nchwc_impl(cfg, outs, tag): """Create the schedule for conv2d_nchw""" outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs @@ -57,7 +77,6 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s - def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None, args={}): """Convolution operator in NCHWc layout. 
""" @@ -260,3 +279,221 @@ def copy_to_texture(stage): if isinstance(N, int): cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) + + +def schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag): + """Create the schedule for depthwise conv2d_nchw4c_ohwi4o""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + def _callback(op): + if op.tag == tag: + args={"memory" : "texture", "shared" : False} + schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, op.output(0), args) + + traverse_inline(s, outs[0].op, _callback) + return s + +def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None, args={}): + """Depthwise convolution operator in NCHWc layout. """ + if out_dtype is None: + out_dtype = Input.dtype + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, channel_chunk, in_height, in_width, channel_block = Input.shape + _, channel_multiplier, kernel_h, kernel_w, _ = Filter.shape + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + out_channel_chunk = simplify(channel_chunk * channel_multiplier) + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + # compute graph + pad_before = [0, 0, pad_top, pad_left, 0] + pad_after = [0, 0, pad_down, pad_right, 0] + temp = nn.pad(Input, pad_before, pad_after, name="pad_temp") + + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + + + if args["memory"] != None: + # NCHWc x CMRSc = [N,(C//4)M,OH,OW, 4c] + # NCHWc x CMRS + # texture: NCH|W|c + # texture: C|MRS|c + Filter_tx = te.compute( + (channel_chunk, channel_multiplier * kernel_h * kernel_w, channel_block), + lambda ffc, mrs, ffb: Filter[ffc, mrs // (kernel_h * kernel_w), (mrs // kernel_w) % kernel_h, mrs % kernel_w, ffb], + name = "packed_filter" + ) + + conv = te.compute( + (batch, out_channel_chunk, out_height, out_width, channel_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + (temp[nn, ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] + * Filter_tx[ffc//channel_multiplier, ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx, ffb]).astype(args["accumulator"]), + axis=[ry, rx], + ), + tag="depthwise_conv2d_nchwc_kcrsk_texture", + ) + else: + conv = te.compute( + (batch, out_channel_chunk, out_height, out_width, channel_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + (temp[nn, ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] + * Filter[ffc//channel_multiplier, ffc % channel_multiplier, ry, rx, ffb]).astype(args["accumulator"]), + axis=[ry, rx], + ), + tag="depthwise_conv2d_nchwc_kcrsk", + ) + return te.compute(conv.shape, lambda n,ffc,y,x,ffb: conv[n,ffc,y,x,ffb].astype("float16"), tag="cast_from_acc" + args["accumulator"][-2:]) + +def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): + """schedule optimized for batch size = 1""" + conv = 
output.op.input_tensors[0] + + ##### space definition begin ##### + n, fc, y, x, fb = s[conv].op.axis + ry, rx = s[conv].op.reduce_axis + cfg.define_split("tile_fc", fc, num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + target = tvm.target.Target.current() + if target.kind.name in ["nvptx", "rocm"]: + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + ##### space definition end ##### + + if args["memory"] != None: + pad_data, flattened_kernel = s[conv].op.input_tensors + kernel = s[flattened_kernel].op.input_tensors[0] + s[flattened_kernel].compute_inline() + else: + pad_data, kernel = s[conv].op.input_tensors + flattened_kernel = kernel + + s[pad_data].compute_inline() + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + kernel = flattened_kernel + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, "local") + else: + output = s.outputs[0].output(0) + s[conv].set_scope("local") + OL = conv + + # create cache stage + if args["memory"] != None: + AT = s.cache_read(pad_data, args["memory"], [OL]) + WT = s.cache_read(kernel, args["memory"], [OL]) + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + copy_to_texture(AT) + copy_to_texture(WT) + + if args["shared"]: + AA = s.cache_read(AT, "shared", [OL]) + WW = s.cache_read(WT, "shared", [OL]) + else: + AA = s.cache_read(pad_data, "shared", [OL]) + WW = s.cache_read(kernel, "shared", [OL]) + + # tile and bind spatial axes + n, fc, y, x, fb = s[output].op.axis + + kernel_scope, n = s[output].split(n, nparts=1) + + bf, vf, tf, fi = cfg["tile_fc"].apply(s, output, fc) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + bf = s[output].fuse(n, bf) + s[output].bind(bf, te.thread_axis("blockIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(vf, te.thread_axis("vthread")) + s[output].bind(vy, te.thread_axis("vthread")) + s[output].bind(vx, te.thread_axis("vthread")) + s[output].bind(tf, te.thread_axis("threadIdx.z")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi, fb) + s[output].vectorize(fb) + + s[OL].compute_at(s[output], tx) + + # tile reduction axes + n, fc, y, x, fb = s[OL].op.axis + + ry, rx = s[OL].op.reduce_axis + ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) + rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) + + s[OL].reorder(ryo, rxo, ryi, rxi, n, fc, y, x, fb) + s[OL].vectorize(fb) + #s[OL].unroll() + + if args["memory"] == None or args["shared"]: + s[AA].compute_at(s[OL], rxo) + s[WW].compute_at(s[OL], rxo) + # cooperative fetching + for load in [AA, WW]: + if args["memory"] != None and load == WW: + n, fyx, v = s[load].op.axis + fused = s[load].fuse(n, fyx) + else: + n, f, y, x, v = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + tz, fused = s[load].split(fused, nparts=cfg["tile_fc"].size[2]) + ty, fused = s[load].split(fused, 
nparts=cfg["tile_y"].size[2]) + tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + s[load].vectorize(v) + + # unroll + s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) + + N, OCC, OH, OW, OCB = get_const_tuple(output.shape) + # OC = OCC * OCB = IC * M + # M = OC // IC == (OCC * OCB) // ICC * ICB + if args["memory"] != None: + ICC, MKHKW, ICB = get_const_tuple(kernel.shape) + M = (OCC * OCB) // (ICC * ICB) + KHKW = MKHKW // M + else: + ICC, M, KH, KW, ICB = get_const_tuple(kernel.shape) + KHKW = KH*KW + + if isinstance(N, int): + cfg.add_flop(2 * N * OH * OW * OCC * OCB * KHKW) From e119c88bed91e8372e256a5373e17465da4255b9 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Tue, 23 Feb 2021 13:46:38 -0800 Subject: [PATCH 04/15] Support injective fusion in Adreno conv2d and depthwise conv2d. --- python/tvm/topi/adreno/conv2d.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py index aeb407ccd948..b2596d24fd9c 100644 --- a/python/tvm/topi/adreno/conv2d.py +++ b/python/tvm/topi/adreno/conv2d.py @@ -180,10 +180,20 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): s[kernel].compute_inline() kernel = flattened_kernel + # conv only if conv.op in s.outputs: output = conv OL = s.cache_write(conv, "local") + # conv -> output (e.g. when casting conv output) + elif output.op in s.outputs: + output = s.outputs[0].output(0) + s[conv].set_scope("local") + OL = conv + # conv -> injective -> ... -> injective -> output else: + # Explicitly mark the output cast to be computed inline + # the other injective ops are inlined via traverse_inline. + s[output].compute_inline() output = s.outputs[0].output(0) s[conv].set_scope("local") OL = conv @@ -397,10 +407,20 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): s[kernel].compute_inline() kernel = flattened_kernel + # conv only if conv.op in s.outputs: output = conv OL = s.cache_write(conv, "local") + # conv -> output (e.g. when casting conv output) + elif output.op in s.outputs: + output = s.outputs[0].output(0) + s[conv].set_scope("local") + OL = conv + # conv -> injective -> ... -> injective -> output else: + # Explicitly mark the output cast to be computed inline + # the other injective ops are inlined via traverse_inline. + s[output].compute_inline() output = s.outputs[0].output(0) s[conv].set_scope("local") OL = conv From 6ad0f5a7c7cc57f5d800cb3859f26e9a0e021b91 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Tue, 23 Feb 2021 13:55:18 -0800 Subject: [PATCH 05/15] Distinguish between depthwise and general group convolution in adreno strategies. 
--- python/tvm/relay/op/strategy/adreno.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index b7b59c0b16a3..2a1f8c6b846e 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -57,7 +57,7 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): ) else: raise RuntimeError("Layout not supported: ("+data_layout+", "+kernel_layout+") - only support NCHW4c / OIHW4o layouts for conv2d") - else: + elif is_depthwise_conv2d(data.shape, data_layout, kernel.shape, kernel_layout, groups): if data_layout == "NCHW4c" and kernel_layout == "OIHW4o": strategy.add_implementation( wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc), @@ -73,5 +73,7 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): ) else: raise RuntimeError("Layout not supported: ("+data_layout+", "+kernel_layout+") - only support NCHW4c / OIHW4o layouts for conv2d") + else: + raise RuntimeError("General group convolution is not currently supported") return strategy From 9404a5fa920808acc07cda6726c1f23ec7b8db90 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Tue, 23 Feb 2021 14:14:55 -0800 Subject: [PATCH 06/15] Schedule suffix: .opencl -> .image2d --- python/tvm/relay/op/strategy/adreno.py | 8 ++++---- python/tvm/topi/adreno/conv2d.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index 2a1f8c6b846e..faff8ea4846f 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -46,13 +46,13 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.adreno.conv2d_nchwc), wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc), - name="conv2d_nchwc.opencl", + name="conv2d_nchwc.image2d", plevel=10 ) strategy.add_implementation( wrap_compute_conv2d(topi.adreno.conv2d_nchwc_acc32), wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc_acc32), - name="conv2d_nchwc_acc32.opencl", + name="conv2d_nchwc_acc32.image2d", plevel=20 ) else: @@ -62,13 +62,13 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc), wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc), - name="depthwise_conv2d_nchwc.opencl", + name="depthwise_conv2d_nchwc.image2d", plevel=10 ) strategy.add_implementation( wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc_acc32), wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc_acc32), - name="depthwise_conv2d_nchwc_acc32.opencl", + name="depthwise_conv2d_nchwc_acc32.image2d", plevel=20 ) else: diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py index b2596d24fd9c..a1f6349a31c9 100644 --- a/python/tvm/topi/adreno/conv2d.py +++ b/python/tvm/topi/adreno/conv2d.py @@ -25,43 +25,43 @@ from ..utils import get_const_tuple, traverse_inline -@autotvm.register_topi_compute("conv2d_nchwc.opencl") +@autotvm.register_topi_compute("conv2d_nchwc.image2d") def conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): """Compute conv2d with NCHWc layout""" args={"memory" : "texture", "shared" : False, "accumulator" : "float16"} return compute_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) -@autotvm.register_topi_compute("conv2d_nchwc_acc32.opencl") 
+@autotvm.register_topi_compute("conv2d_nchwc_acc32.image2d") def conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): """Compute conv2d with NCHWc layout""" args={"memory" : "texture", "shared" : False, "accumulator" : "float32"} return compute_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) -@autotvm.register_topi_schedule("conv2d_nchwc.opencl") +@autotvm.register_topi_schedule("conv2d_nchwc.image2d") def schedule_conv2d_nchwc(cfg, outs): return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc16") -@autotvm.register_topi_schedule("conv2d_nchwc_acc32.opencl") +@autotvm.register_topi_schedule("conv2d_nchwc_acc32.image2d") def schedule_conv2d_nchwc_acc32(cfg, outs): return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc32") -@autotvm.register_topi_compute("depthwise_conv2d_nchwc.opencl") +@autotvm.register_topi_compute("depthwise_conv2d_nchwc.image2d") def depthwise_conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): """Compute depthwise_conv2d with NCHWc layout""" args={"memory" : "texture", "shared" : False, "accumulator" : "float16"} return compute_depthwise_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) -@autotvm.register_topi_compute("depthwise_conv2d_nchwc_acc32.opencl") +@autotvm.register_topi_compute("depthwise_conv2d_nchwc_acc32.image2d") def depthwise_conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): """Compute depthwise_conv2d with NCHWc layout""" args={"memory" : "texture", "shared" : False, "accumulator" : "float32"} return compute_depthwise_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) -@autotvm.register_topi_schedule("depthwise_conv2d_nchwc.opencl") +@autotvm.register_topi_schedule("depthwise_conv2d_nchwc.image2d") def schedule_depthwise_conv2d_nchwc(cfg, outs): return schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc16") -@autotvm.register_topi_schedule("depthwise_conv2d_nchwc_acc32.opencl") +@autotvm.register_topi_schedule("depthwise_conv2d_nchwc_acc32.image2d") def schedule_depthwise_conv2d_nchwc_acc32(cfg, outs): return schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc32") From f4e58e56f061f5151857d978d4a8adc762c409d8 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Thu, 11 Mar 2021 10:22:43 -0800 Subject: [PATCH 07/15] mali -> cuda for conv2d_nchw --- python/tvm/relay/op/strategy/adreno.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index faff8ea4846f..a5b0a22c7dab 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -38,9 +38,9 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): if groups == 1: if data_layout == "NCHW" and kernel_layout == "OIHW": strategy.add_implementation( - wrap_compute_conv2d(topi.mali.conv2d_nchw_spatial_pack), - wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_spatial_pack), - name="conv2d_nchw_spatial_pack.mali", + wrap_compute_conv2d(topi.cuda.conv2d_nchw), + wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw), + name="conv2d_nchw.cuda", ) elif data_layout == "NCHW4c" and kernel_layout == "OIHW4o": strategy.add_implementation( @@ -58,7 +58,13 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): else: raise RuntimeError("Layout not supported: ("+data_layout+", "+kernel_layout+") - only support NCHW4c 
/ OIHW4o layouts for conv2d") elif is_depthwise_conv2d(data.shape, data_layout, kernel.shape, kernel_layout, groups): - if data_layout == "NCHW4c" and kernel_layout == "OIHW4o": + if data_layout == "NCHW" and kernel_layout == "OIHW": + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw), + wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw), + name="depthwise_conv2d_nchw.cuda", + ) + elif data_layout == "NCHW4c" and kernel_layout == "OIHW4o": strategy.add_implementation( wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc), wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc), From 72bc4957b2b613a1b9332a2aad96fce29cb2421b Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Wed, 17 Mar 2021 10:53:04 -0700 Subject: [PATCH 08/15] [Part 1/3] Support texture:weight lowering convention for externally provided texture buffers. Need to propagate this to allocated textures when cache_read(texture) is used for weights. --- python/tvm/topi/adreno/conv2d.py | 248 +++++++++++++++---------------- 1 file changed, 124 insertions(+), 124 deletions(-) diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py index a1f6349a31c9..a3e3842b4993 100644 --- a/python/tvm/topi/adreno/conv2d.py +++ b/python/tvm/topi/adreno/conv2d.py @@ -115,34 +115,34 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dty ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") - if args["memory"] != None: - # NCHWc x KCRSk - # texture: NCH|W|c - # texture: K|CRS|k - Filter_tx = te.compute( - (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), - lambda ffc, crs, ffb: Filter[ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb], - name = "packed_filter" - ) - conv = te.compute( - (batch, num_filter_chunk, out_height, out_width, num_filter_block), - lambda nn, ffc, yy, xx, ffb: te.sum( - (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] - * Filter_tx[ffc, ((rcc * in_channel_block + rcb)*kernel_h + ry)*kernel_w + rx, ffb]).astype(args["accumulator"]), - axis=[rcc, rcb, ry, rx], - ), - tag="conv2d_nchwc", - ) - else: - conv = te.compute( - (batch, num_filter_chunk, out_height, out_width, num_filter_block), - lambda nn, ffc, yy, xx, ffb: te.sum( - (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] - * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb]).astype(args["accumulator"]), - axis=[rcc, rcb, ry, rx], - ), - tag="conv2d_nchwc", - ) + # if args["memory"] != None: + # # NCHWc x KCRSk + # # texture: NCH|W|c + # # texture: K|CRS|k + # Filter_tx = te.compute( + # (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), + # lambda ffc, crs, ffb: Filter[ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb], + # name = "packed_filter" + # ) + # conv = te.compute( + # (batch, num_filter_chunk, out_height, out_width, num_filter_block), + # lambda nn, ffc, yy, xx, ffb: te.sum( + # (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] + # * Filter_tx[ffc, ((rcc * in_channel_block + rcb)*kernel_h + ry)*kernel_w + rx, ffb]).astype(args["accumulator"]), + # axis=[rcc, rcb, ry, rx], + # ), + # tag="conv2d_nchwc", + # ) + # else: + conv = te.compute( + (batch, num_filter_chunk, out_height, out_width, num_filter_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * 
stride_w + rx * dilation_w, rcb] + * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb]).astype(args["accumulator"]), + axis=[rcc, rcb, ry, rx], + ), + tag="conv2d_nchwc", + ) return te.compute(conv.shape, lambda n,fc,y,x,fb: conv[n,fc,y,x,fb].astype("float16"), tag="cast_from_acc" + args["accumulator"][-2:]) def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): @@ -167,18 +167,18 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): cfg.define_knob("unroll_explicit", [0, 1]) ##### space definition end ##### - if args["memory"] != None: - pad_data, flattened_kernel = s[conv].op.input_tensors - kernel = s[flattened_kernel].op.input_tensors[0] - s[flattened_kernel].compute_inline() - else: - pad_data, kernel = s[conv].op.input_tensors - flattened_kernel = kernel + # if args["memory"] != None: + # pad_data, flattened_kernel = s[conv].op.input_tensors + # kernel = s[flattened_kernel].op.input_tensors[0] + # s[flattened_kernel].compute_inline() + # else: + pad_data, kernel = s[conv].op.input_tensors + #flattened_kernel = kernel s[pad_data].compute_inline() if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() - kernel = flattened_kernel + #kernel = flattened_kernel # conv only if conv.op in s.outputs: @@ -199,25 +199,25 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): OL = conv # create cache stage - if args["memory"] != None: - AT = s.cache_read(pad_data, args["memory"], [OL]) - WT = s.cache_read(kernel, args["memory"], [OL]) - def copy_to_texture(stage): - axes = s[stage].op.axis - fused = s[stage].fuse(*axes[:-1]) - block, thread = s[stage].split(fused, factor=32) - s[stage].vectorize(axes[-1]) - s[stage].bind(block, te.thread_axis("blockIdx.x")) - s[stage].bind(thread, te.thread_axis("threadIdx.x")) - copy_to_texture(AT) - copy_to_texture(WT) - - if args["shared"]: - AA = s.cache_read(AT, "shared", [OL]) - WW = s.cache_read(WT, "shared", [OL]) - else: - AA = s.cache_read(pad_data, "shared", [OL]) - WW = s.cache_read(kernel, "shared", [OL]) + # if args["memory"] != None: + # AT = s.cache_read(pad_data, args["memory"], [OL]) + # WT = s.cache_read(kernel, args["memory"], [OL]) + # def copy_to_texture(stage): + # axes = s[stage].op.axis + # fused = s[stage].fuse(*axes[:-1]) + # block, thread = s[stage].split(fused, factor=32) + # s[stage].vectorize(axes[-1]) + # s[stage].bind(block, te.thread_axis("blockIdx.x")) + # s[stage].bind(thread, te.thread_axis("threadIdx.x")) + # copy_to_texture(AT) + # copy_to_texture(WT) + + # if args["shared"]: + # AA = s.cache_read(AT, "shared", [OL]) + # WW = s.cache_read(WT, "shared", [OL]) + # else: + # AA = s.cache_read(pad_data, "shared", [OL]) + # WW = s.cache_read(kernel, "shared", [OL]) # tile and bind spatial axes n, fc, y, x, fb = s[output].op.axis @@ -280,11 +280,11 @@ def copy_to_texture(stage): s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) N, OCC, OH, OW, OCB = get_const_tuple(output.shape) - if args["memory"] != None: - _, ICKHKW, _ = get_const_tuple(kernel.shape) - else: - _, IC, KH, KW, _ = get_const_tuple(kernel.shape) - ICKHKW = IC*KH*KW + # if args["memory"] != None: + # _, ICKHKW, _ = get_const_tuple(kernel.shape) + # else: + _, IC, KH, KW, _ = get_const_tuple(kernel.shape) + ICKHKW = IC*KH*KW if isinstance(N, int): @@ -341,36 +341,36 @@ def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilatio rx = te.reduce_axis((0, kernel_w), name="rx") - if args["memory"] != None: - # NCHWc x CMRSc = [N,(C//4)M,OH,OW, 4c] - # NCHWc 
x CMRS - # texture: NCH|W|c - # texture: C|MRS|c - Filter_tx = te.compute( - (channel_chunk, channel_multiplier * kernel_h * kernel_w, channel_block), - lambda ffc, mrs, ffb: Filter[ffc, mrs // (kernel_h * kernel_w), (mrs // kernel_w) % kernel_h, mrs % kernel_w, ffb], - name = "packed_filter" - ) - - conv = te.compute( - (batch, out_channel_chunk, out_height, out_width, channel_block), - lambda nn, ffc, yy, xx, ffb: te.sum( - (temp[nn, ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] - * Filter_tx[ffc//channel_multiplier, ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx, ffb]).astype(args["accumulator"]), - axis=[ry, rx], - ), - tag="depthwise_conv2d_nchwc_kcrsk_texture", - ) - else: - conv = te.compute( - (batch, out_channel_chunk, out_height, out_width, channel_block), - lambda nn, ffc, yy, xx, ffb: te.sum( - (temp[nn, ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] - * Filter[ffc//channel_multiplier, ffc % channel_multiplier, ry, rx, ffb]).astype(args["accumulator"]), - axis=[ry, rx], - ), - tag="depthwise_conv2d_nchwc_kcrsk", - ) + # if args["memory"] != None: + # # NCHWc x CMRSc = [N,(C//4)M,OH,OW, 4c] + # # NCHWc x CMRS + # # texture: NCH|W|c + # # texture: C|MRS|c + # Filter_tx = te.compute( + # (channel_chunk, channel_multiplier * kernel_h * kernel_w, channel_block), + # lambda ffc, mrs, ffb: Filter[ffc, mrs // (kernel_h * kernel_w), (mrs // kernel_w) % kernel_h, mrs % kernel_w, ffb], + # name = "packed_filter" + # ) + + # conv = te.compute( + # (batch, out_channel_chunk, out_height, out_width, channel_block), + # lambda nn, ffc, yy, xx, ffb: te.sum( + # (temp[nn, ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] + # * Filter_tx[ffc//channel_multiplier, ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx, ffb]).astype(args["accumulator"]), + # axis=[ry, rx], + # ), + # tag="depthwise_conv2d_nchwc_kcrsk_texture", + # ) + # else: + conv = te.compute( + (batch, out_channel_chunk, out_height, out_width, channel_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + (temp[nn, ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] + * Filter[ffc//channel_multiplier, ffc % channel_multiplier, ry, rx, ffb]).astype(args["accumulator"]), + axis=[ry, rx], + ), + tag="depthwise_conv2d_nchwc_kcrsk", + ) return te.compute(conv.shape, lambda n,ffc,y,x,ffb: conv[n,ffc,y,x,ffb].astype("float16"), tag="cast_from_acc" + args["accumulator"][-2:]) def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): @@ -394,18 +394,18 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): cfg.define_knob("unroll_explicit", [0, 1]) ##### space definition end ##### - if args["memory"] != None: - pad_data, flattened_kernel = s[conv].op.input_tensors - kernel = s[flattened_kernel].op.input_tensors[0] - s[flattened_kernel].compute_inline() - else: - pad_data, kernel = s[conv].op.input_tensors - flattened_kernel = kernel + # if args["memory"] != None: + # pad_data, flattened_kernel = s[conv].op.input_tensors + # kernel = s[flattened_kernel].op.input_tensors[0] + # s[flattened_kernel].compute_inline() + # else: + pad_data, kernel = s[conv].op.input_tensors + #flattened_kernel = kernel s[pad_data].compute_inline() if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() - kernel = flattened_kernel + #kernel = flattened_kernel # conv only if conv.op in 
s.outputs: @@ -426,25 +426,25 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): OL = conv # create cache stage - if args["memory"] != None: - AT = s.cache_read(pad_data, args["memory"], [OL]) - WT = s.cache_read(kernel, args["memory"], [OL]) - def copy_to_texture(stage): - axes = s[stage].op.axis - fused = s[stage].fuse(*axes[:-1]) - block, thread = s[stage].split(fused, factor=32) - s[stage].vectorize(axes[-1]) - s[stage].bind(block, te.thread_axis("blockIdx.x")) - s[stage].bind(thread, te.thread_axis("threadIdx.x")) - copy_to_texture(AT) - copy_to_texture(WT) - - if args["shared"]: - AA = s.cache_read(AT, "shared", [OL]) - WW = s.cache_read(WT, "shared", [OL]) - else: - AA = s.cache_read(pad_data, "shared", [OL]) - WW = s.cache_read(kernel, "shared", [OL]) + # if args["memory"] != None: + # AT = s.cache_read(pad_data, args["memory"], [OL]) + # WT = s.cache_read(kernel, args["memory"], [OL]) + # def copy_to_texture(stage): + # axes = s[stage].op.axis + # fused = s[stage].fuse(*axes[:-1]) + # block, thread = s[stage].split(fused, factor=32) + # s[stage].vectorize(axes[-1]) + # s[stage].bind(block, te.thread_axis("blockIdx.x")) + # s[stage].bind(thread, te.thread_axis("threadIdx.x")) + # copy_to_texture(AT) + # copy_to_texture(WT) + + # if args["shared"]: + # AA = s.cache_read(AT, "shared", [OL]) + # WW = s.cache_read(WT, "shared", [OL]) + # else: + # AA = s.cache_read(pad_data, "shared", [OL]) + # WW = s.cache_read(kernel, "shared", [OL]) # tile and bind spatial axes n, fc, y, x, fb = s[output].op.axis @@ -507,13 +507,13 @@ def copy_to_texture(stage): N, OCC, OH, OW, OCB = get_const_tuple(output.shape) # OC = OCC * OCB = IC * M # M = OC // IC == (OCC * OCB) // ICC * ICB - if args["memory"] != None: - ICC, MKHKW, ICB = get_const_tuple(kernel.shape) - M = (OCC * OCB) // (ICC * ICB) - KHKW = MKHKW // M - else: - ICC, M, KH, KW, ICB = get_const_tuple(kernel.shape) - KHKW = KH*KW + # if args["memory"] != None: + # ICC, MKHKW, ICB = get_const_tuple(kernel.shape) + # M = (OCC * OCB) // (ICC * ICB) + # KHKW = MKHKW // M + # else: + ICC, M, KH, KW, ICB = get_const_tuple(kernel.shape) + KHKW = KH*KW if isinstance(N, int): cfg.add_flop(2 * N * OH * OW * OCC * OCB * KHKW) From fda5fa51c3dcdf8babd7a0837667ac52359ebbe2 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Thu, 11 Mar 2021 11:19:04 -0800 Subject: [PATCH 09/15] Use cache_read("texture") when tuning via autotvm to simulate graph runtime behavior of providing external texture buffers. 
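During tuning there is no graph runtime present to hand the kernel
image2d-backed buffers, so the schedule now stages texture reads explicitly
whenever autotvm.GLOBAL_SCOPE.in_tuning is set. Condensed from the
copy_to_texture logic in this patch, the pattern is roughly as follows (the
helper name is illustrative, not part of the patch):

    from tvm import te

    def stage_texture_reads(s, pad_data, kernel, OL):
        # tuning-only: emulate externally provided texture buffers
        AT = s.cache_read(pad_data, "texture", [OL])
        WT = s.cache_read(kernel, "texture", [OL])
        for stage in (AT, WT):
            axes = s[stage].op.axis
            fused = s[stage].fuse(*axes[:-1])  # fuse all but the block axis
            block, thread = s[stage].split(fused, factor=32)
            s[stage].vectorize(axes[-1])       # vectorize the 4-wide block
            s[stage].bind(block, te.thread_axis("blockIdx.x"))
            s[stage].bind(thread, te.thread_axis("threadIdx.x"))
        return AT, WT

In a deployed module the graph runtime passes textures directly, so these
copy stages are skipped and the convolution reads its inputs as-is.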
--- python/tvm/topi/adreno/conv2d.py | 275 ++++++++++++++++---------------
 1 file changed, 141 insertions(+), 134 deletions(-)
diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py
index a3e3842b4993..fcc30cb07a70 100644
--- a/python/tvm/topi/adreno/conv2d.py
+++ b/python/tvm/topi/adreno/conv2d.py
@@ -28,13 +28,13 @@
 @autotvm.register_topi_compute("conv2d_nchwc.image2d")
 def conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
 """Compute conv2d with NCHWc layout"""
- args={"memory" : "texture", "shared" : False, "accumulator" : "float16"}
+ args={"shared" : False, "accumulator" : "float16"}
 return compute_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args)
 @autotvm.register_topi_compute("conv2d_nchwc_acc32.image2d")
 def conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
 """Compute conv2d with NCHWc layout"""
- args={"memory" : "texture", "shared" : False, "accumulator" : "float32"}
+ args={"shared" : False, "accumulator" : "float32"}
 return compute_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args)
 @autotvm.register_topi_schedule("conv2d_nchwc.image2d")
@@ -48,13 +48,13 @@ def schedule_conv2d_nchwc_acc32(cfg, outs):
 @autotvm.register_topi_compute("depthwise_conv2d_nchwc.image2d")
 def depthwise_conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
 """Compute depthwise_conv2d with NCHWc layout"""
- args={"memory" : "texture", "shared" : False, "accumulator" : "float16"}
+ args={"shared" : False, "accumulator" : "float16"}
 return compute_depthwise_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args)
 @autotvm.register_topi_compute("depthwise_conv2d_nchwc_acc32.image2d")
 def depthwise_conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
 """Compute depthwise_conv2d with NCHWc layout"""
- args={"memory" : "texture", "shared" : False, "accumulator" : "float32"}
+ args={"shared" : False, "accumulator" : "float32"}
 return compute_depthwise_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args)
 @autotvm.register_topi_schedule("depthwise_conv2d_nchwc.image2d")
@@ -71,7 +71,7 @@ def schedule_conv2d_nchwc_impl(cfg, outs, tag):
 s = te.create_schedule([x.op for x in outs])
 def _callback(op):
 if op.tag == tag:
- args={"memory" : "texture", "shared" : False}
+ args={"shared" : False}
 schedule_conv2d_NCHWc_KCRSk(cfg, s, op.output(0), args)
 traverse_inline(s, outs[0].op, _callback)
@@ -115,34 +115,41 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dty
 ry = te.reduce_axis((0, kernel_h), name="ry")
 rx = te.reduce_axis((0, kernel_w), name="rx")
- # if args["memory"] != None:
- # # NCHWc x KCRSk
- # # texture: NCH|W|c
- # # texture: K|CRS|k
- # Filter_tx = te.compute(
- # (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block),
- # lambda ffc, crs, ffb: Filter[ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb],
- # name = "packed_filter"
- # )
- # conv = te.compute(
- # (batch, num_filter_chunk, out_height, out_width, num_filter_block),
- # lambda nn, ffc, yy, xx, ffb: te.sum(
- # (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb]
- # * Filter_tx[ffc, ((rcc * in_channel_block + rcb)*kernel_h + ry)*kernel_w + rx, ffb]).astype(args["accumulator"]),
- # axis=[rcc, rcb, ry, rx],
- # ),
- # tag="conv2d_nchwc",
- # )
- # else:
- conv = te.compute(
- (batch, num_filter_chunk, out_height, out_width, num_filter_block),
- lambda nn, ffc, yy, xx, ffb: te.sum(
- (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb]
- * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb]).astype(args["accumulator"]),
- axis=[rcc, rcb, ry, rx],
- ),
- tag="conv2d_nchwc",
- )
+ # When tuning, insert a cache_read("texture") stage to properly test
+ # performance of kernels that utilize texture inputs. The cache_read
+ # is not needed when using the graph_runtime which supports passing
+ # in external texture buffers. This can be removed once AutoTVM tuning
+ # supports capturing this runtime information during task extraction
+ # or once texture lowering in tir.TextureFlatten supports cache_read
+ # cancellation when padding is utilized.
+ if autotvm.GLOBAL_SCOPE.in_tuning:
+ # NCHWc x KCRSk
+ # texture: NCH|W|c
+ # texture: K|CRS|k
+ Filter_tx = te.compute(
+ (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block),
+ lambda ffc, crs, ffb: Filter[ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb],
+ name = "packed_filter"
+ )
+ conv = te.compute(
+ (batch, num_filter_chunk, out_height, out_width, num_filter_block),
+ lambda nn, ffc, yy, xx, ffb: te.sum(
+ (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb]
+ * Filter_tx[ffc, ((rcc * in_channel_block + rcb)*kernel_h + ry)*kernel_w + rx, ffb]).astype(args["accumulator"]),
+ axis=[rcc, rcb, ry, rx],
+ ),
+ tag="conv2d_nchwc",
+ )
+ else:
+ conv =
te.compute( - (batch, num_filter_chunk, out_height, out_width, num_filter_block), - lambda nn, ffc, yy, xx, ffb: te.sum( - (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] - * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb]).astype(args["accumulator"]), - axis=[rcc, rcb, ry, rx], - ), - tag="conv2d_nchwc", - ) + # When tuning, insert a cache_read("texture") stage to properly test + # performance of kernels that utilize texture inputs. The cache_read + # is not needed when using the graph_runtime, which supports passing + # in external texture buffers. This can be removed once AutoTVM tuning + # supports capturing this runtime information during task extraction + # or once texture lowering in tir.TextureFlatten supports cache_read + # cancellation when padding is utilized. + if autotvm.GLOBAL_SCOPE.in_tuning: + # NCHWc x KCRSk + # texture: NCH|W|c + # texture: K|CRS|k + Filter_tx = te.compute( + (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), + lambda ffc, crs, ffb: Filter[ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb], + name = "packed_filter" + ) + conv = te.compute( + (batch, num_filter_chunk, out_height, out_width, num_filter_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] + * Filter_tx[ffc, ((rcc * in_channel_block + rcb)*kernel_h + ry)*kernel_w + rx, ffb]).astype(args["accumulator"]), + axis=[rcc, rcb, ry, rx], + ), + tag="conv2d_nchwc", + ) + else: + conv = te.compute( + (batch, num_filter_chunk, out_height, out_width, num_filter_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] + * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb]).astype(args["accumulator"]), + axis=[rcc, rcb, ry, rx], + ), + tag="conv2d_nchwc", + ) return te.compute(conv.shape, lambda n,fc,y,x,fb: conv[n,fc,y,x,fb].astype("float16"), tag="cast_from_acc" + args["accumulator"][-2:]) def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): @@ -167,18 +174,18 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): cfg.define_knob("unroll_explicit", [0, 1]) ##### space definition end ##### - # if args["memory"] != None: - # pad_data, flattened_kernel = s[conv].op.input_tensors - # kernel = s[flattened_kernel].op.input_tensors[0] - # s[flattened_kernel].compute_inline() - # else: - pad_data, kernel = s[conv].op.input_tensors - #flattened_kernel = kernel + if autotvm.GLOBAL_SCOPE.in_tuning: + pad_data, flattened_kernel = s[conv].op.input_tensors + kernel = s[flattened_kernel].op.input_tensors[0] + s[flattened_kernel].compute_inline() + else: + pad_data, kernel = s[conv].op.input_tensors + flattened_kernel = kernel s[pad_data].compute_inline() if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() - #kernel = flattened_kernel + kernel = flattened_kernel # conv only if conv.op in s.outputs: @@ -199,25 +206,25 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): OL = conv # create cache stage - # if args["memory"] != None: - # AT = s.cache_read(pad_data, args["memory"], [OL]) - # WT = s.cache_read(kernel, args["memory"], [OL]) - # def copy_to_texture(stage): - # axes = s[stage].op.axis - # fused = s[stage].fuse(*axes[:-1]) - # block, thread = s[stage].split(fused, factor=32) - # s[stage].vectorize(axes[-1]) - # s[stage].bind(block, te.thread_axis("blockIdx.x")) - # s[stage].bind(thread,
te.thread_axis("threadIdx.x")) - # copy_to_texture(AT) - # copy_to_texture(WT) - - # if args["shared"]: - # AA = s.cache_read(AT, "shared", [OL]) - # WW = s.cache_read(WT, "shared", [OL]) - # else: - # AA = s.cache_read(pad_data, "shared", [OL]) - # WW = s.cache_read(kernel, "shared", [OL]) + if autotvm.GLOBAL_SCOPE.in_tuning: + AT = s.cache_read(pad_data, "texture", [OL]) + WT = s.cache_read(kernel, "texture", [OL]) + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + copy_to_texture(AT) + copy_to_texture(WT) + + if args["shared"]: + AA = s.cache_read(AT, "shared", [OL]) + WW = s.cache_read(WT, "shared", [OL]) + elif args["shared"]: + AA = s.cache_read(pad_data, "shared", [OL]) + WW = s.cache_read(kernel, "shared", [OL]) # tile and bind spatial axes n, fc, y, x, fb = s[output].op.axis @@ -256,12 +263,12 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): s[OL].vectorize(fb) s[OL].unroll(rcb) - if args["memory"] == None or args["shared"]: + if args["shared"]: s[AA].compute_at(s[OL], rxo) s[WW].compute_at(s[OL], rxo) # cooperative fetching for load in [AA, WW]: - if args["memory"] != None and load == WW: + if autotvm.GLOBAL_SCOPE.in_tuning and load == WW: n, fyx, v = s[load].op.axis fused = s[load].fuse(n, fyx) else: @@ -280,11 +287,11 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) N, OCC, OH, OW, OCB = get_const_tuple(output.shape) - # if args["memory"] != None: - # _, ICKHKW, _ = get_const_tuple(kernel.shape) - # else: - _, IC, KH, KW, _ = get_const_tuple(kernel.shape) - ICKHKW = IC*KH*KW + if autotvm.GLOBAL_SCOPE.in_tuning: + _, ICKHKW, _ = get_const_tuple(kernel.shape) + else: + _, IC, KH, KW, _ = get_const_tuple(kernel.shape) + ICKHKW = IC*KH*KW if isinstance(N, int): @@ -297,7 +304,7 @@ def schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag): s = te.create_schedule([x.op for x in outs]) def _callback(op): if op.tag == tag: - args={"memory" : "texture", "shared" : False} + args={"shared" : False} schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, op.output(0), args) traverse_inline(s, outs[0].op, _callback) @@ -341,36 +348,36 @@ def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilatio rx = te.reduce_axis((0, kernel_w), name="rx") - # if args["memory"] != None: - # # NCHWc x CMRSc = [N,(C//4)M,OH,OW, 4c] - # # NCHWc x CMRS - # # texture: NCH|W|c - # # texture: C|MRS|c - # Filter_tx = te.compute( - # (channel_chunk, channel_multiplier * kernel_h * kernel_w, channel_block), - # lambda ffc, mrs, ffb: Filter[ffc, mrs // (kernel_h * kernel_w), (mrs // kernel_w) % kernel_h, mrs % kernel_w, ffb], - # name = "packed_filter" - # ) - - # conv = te.compute( - # (batch, out_channel_chunk, out_height, out_width, channel_block), - # lambda nn, ffc, yy, xx, ffb: te.sum( - # (temp[nn, ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] - # * Filter_tx[ffc//channel_multiplier, ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx, ffb]).astype(args["accumulator"]), - # axis=[ry, rx], - # ), - # tag="depthwise_conv2d_nchwc_kcrsk_texture", - # ) - # else: - conv = te.compute( - (batch, out_channel_chunk, out_height, out_width, channel_block), - lambda nn, ffc, yy, xx, ffb: te.sum( - (temp[nn, 
ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] - * Filter[ffc//channel_multiplier, ffc % channel_multiplier, ry, rx, ffb]).astype(args["accumulator"]), - axis=[ry, rx], - ), - tag="depthwise_conv2d_nchwc_kcrsk", - ) + if autotvm.GLOBAL_SCOPE.in_tuning: + # NCHWc x CMRSc = [N,(C//4)M,OH,OW, 4c] + # NCHWc x CMRS + # texture: NCH|W|c + # texture: C|MRS|c + Filter_tx = te.compute( + (channel_chunk, channel_multiplier * kernel_h * kernel_w, channel_block), + lambda ffc, mrs, ffb: Filter[ffc, mrs // (kernel_h * kernel_w), (mrs // kernel_w) % kernel_h, mrs % kernel_w, ffb], + name = "packed_filter" + ) + + conv = te.compute( + (batch, out_channel_chunk, out_height, out_width, channel_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + (temp[nn, ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] + * Filter_tx[ffc//channel_multiplier, ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx, ffb]).astype(args["accumulator"]), + axis=[ry, rx], + ), + tag="depthwise_conv2d_nchwc_kcrsk_texture", + ) + else: + conv = te.compute( + (batch, out_channel_chunk, out_height, out_width, channel_block), + lambda nn, ffc, yy, xx, ffb: te.sum( + (temp[nn, ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] + * Filter[ffc//channel_multiplier, ffc % channel_multiplier, ry, rx, ffb]).astype(args["accumulator"]), + axis=[ry, rx], + ), + tag="depthwise_conv2d_nchwc_kcrsk", + ) return te.compute(conv.shape, lambda n,ffc,y,x,ffb: conv[n,ffc,y,x,ffb].astype("float16"), tag="cast_from_acc" + args["accumulator"][-2:]) def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): @@ -394,18 +401,18 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): cfg.define_knob("unroll_explicit", [0, 1]) ##### space definition end ##### - # if args["memory"] != None: - # pad_data, flattened_kernel = s[conv].op.input_tensors - # kernel = s[flattened_kernel].op.input_tensors[0] - # s[flattened_kernel].compute_inline() - # else: - pad_data, kernel = s[conv].op.input_tensors - #flattened_kernel = kernel + if autotvm.GLOBAL_SCOPE.in_tuning: + pad_data, flattened_kernel = s[conv].op.input_tensors + kernel = s[flattened_kernel].op.input_tensors[0] + s[flattened_kernel].compute_inline() + else: + pad_data, kernel = s[conv].op.input_tensors + flattened_kernel = kernel s[pad_data].compute_inline() if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: s[kernel].compute_inline() - #kernel = flattened_kernel + kernel = flattened_kernel # conv only if conv.op in s.outputs: @@ -426,25 +433,25 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): OL = conv # create cache stage - # if args["memory"] != None: - # AT = s.cache_read(pad_data, args["memory"], [OL]) - # WT = s.cache_read(kernel, args["memory"], [OL]) - # def copy_to_texture(stage): - # axes = s[stage].op.axis - # fused = s[stage].fuse(*axes[:-1]) - # block, thread = s[stage].split(fused, factor=32) - # s[stage].vectorize(axes[-1]) - # s[stage].bind(block, te.thread_axis("blockIdx.x")) - # s[stage].bind(thread, te.thread_axis("threadIdx.x")) - # copy_to_texture(AT) - # copy_to_texture(WT) - - # if args["shared"]: - # AA = s.cache_read(AT, "shared", [OL]) - # WW = s.cache_read(WT, "shared", [OL]) - # else: - # AA = s.cache_read(pad_data, "shared", [OL]) - # WW = s.cache_read(kernel, "shared", [OL]) + if autotvm.GLOBAL_SCOPE.in_tuning: + AT = s.cache_read(pad_data, "texture", [OL]) + WT = 
s.cache_read(kernel, "texture", [OL]) + def copy_to_texture(stage): + axes = s[stage].op.axis + fused = s[stage].fuse(*axes[:-1]) + block, thread = s[stage].split(fused, factor=32) + s[stage].vectorize(axes[-1]) + s[stage].bind(block, te.thread_axis("blockIdx.x")) + s[stage].bind(thread, te.thread_axis("threadIdx.x")) + copy_to_texture(AT) + copy_to_texture(WT) + + if args["shared"]: + AA = s.cache_read(AT, "shared", [OL]) + WW = s.cache_read(WT, "shared", [OL]) + elif args["shared"]: + AA = s.cache_read(pad_data, "shared", [OL]) + WW = s.cache_read(kernel, "shared", [OL]) # tile and bind spatial axes n, fc, y, x, fb = s[output].op.axis @@ -481,12 +488,12 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): s[OL].vectorize(fb) #s[OL].unroll() - if args["memory"] == None or args["shared"]: + if args["shared"]: s[AA].compute_at(s[OL], rxo) s[WW].compute_at(s[OL], rxo) # cooperative fetching for load in [AA, WW]: - if args["memory"] != None and load == WW: + if autotvm.GLOBAL_SCOPE.in_tuning: n, fyx, v = s[load].op.axis fused = s[load].fuse(n, fyx) else: @@ -507,13 +514,13 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): N, OCC, OH, OW, OCB = get_const_tuple(output.shape) # OC = OCC * OCB = IC * M # M = OC // IC == (OCC * OCB) // ICC * ICB - # if args["memory"] != None: - # ICC, MKHKW, ICB = get_const_tuple(kernel.shape) - # M = (OCC * OCB) // (ICC * ICB) - # KHKW = MKHKW // M - # else: - ICC, M, KH, KW, ICB = get_const_tuple(kernel.shape) - KHKW = KH*KW + if autotvm.GLOBAL_SCOPE.in_tuning: + ICC, MKHKW, ICB = get_const_tuple(kernel.shape) + M = (OCC * OCB) // (ICC * ICB) + KHKW = MKHKW // M + else: + ICC, M, KH, KW, ICB = get_const_tuple(kernel.shape) + KHKW = KH*KW if isinstance(N, int): cfg.add_flop(2 * N * OH * OW * OCC * OCB * KHKW) From d1a2f1589a3e130d81fe3ebe8f1f4d3ac7388e38 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Tue, 16 Mar 2021 22:25:36 -0700 Subject: [PATCH 10/15] Remove comment --- python/tvm/topi/adreno/conv2d.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py index fcc30cb07a70..0650b1095b61 100644 --- a/python/tvm/topi/adreno/conv2d.py +++ b/python/tvm/topi/adreno/conv2d.py @@ -258,7 +258,6 @@ def copy_to_texture(stage): ryo, ryi = cfg["tile_ry"].apply(s, OL, ry) rxo, rxi = cfg["tile_rx"].apply(s, OL, rx) - # TODO(csullivan): check position of rcb s[OL].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb) s[OL].vectorize(fb) s[OL].unroll(rcb) From b83b9927e6e044be67cb1afaa129c53850c81704 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Tue, 4 May 2021 15:24:49 -0700 Subject: [PATCH 11/15] Update topi schedules to use global.texture scope. --- python/tvm/topi/adreno/conv2d.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py index 0650b1095b61..72b378b9575d 100644 --- a/python/tvm/topi/adreno/conv2d.py +++ b/python/tvm/topi/adreno/conv2d.py @@ -115,7 +115,7 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dty ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") - # When tuning, insert a cache_read("texture") stage to properly test + # When tuning, insert a cache_read("global.texture") stage to properly test # performance of kernels that utilize texture inputs. The cache_read # is not needed when using the graph_runtime, which supports passing # in external texture buffers.
This can be removed once AutoTVM tuning @@ -207,8 +207,8 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): # create cache stage if autotvm.GLOBAL_SCOPE.in_tuning: - AT = s.cache_read(pad_data, "texture", [OL]) - WT = s.cache_read(kernel, "texture", [OL]) + AT = s.cache_read(pad_data, "global.texture", [OL]) + WT = s.cache_read(kernel, "global.texture-weight", [OL]) def copy_to_texture(stage): axes = s[stage].op.axis fused = s[stage].fuse(*axes[:-1]) @@ -433,8 +433,8 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): # create cache stage if autotvm.GLOBAL_SCOPE.in_tuning: - AT = s.cache_read(pad_data, "texture", [OL]) - WT = s.cache_read(kernel, "texture", [OL]) + AT = s.cache_read(pad_data, "global.texture", [OL]) + WT = s.cache_read(kernel, "global.texture-weight", [OL]) def copy_to_texture(stage): axes = s[stage].op.axis fused = s[stage].fuse(*axes[:-1]) From bd665ef55f8ff5e7732ed93735abad731564559b Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Thu, 22 Jul 2021 10:21:52 -0700 Subject: [PATCH 12/15] Apply black --- python/tvm/relay/op/strategy/adreno.py | 30 ++++-- python/tvm/target/target.py | 2 + python/tvm/topi/adreno/conv2d.py | 140 +++++++++++++++++++------ 3 files changed, 133 insertions(+), 39 deletions(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index a5b0a22c7dab..b65cfe0b014f 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -21,6 +21,7 @@ from .generic import * from .. import op as _op + @conv2d_strategy.register("adreno") def conv2d_strategy_adreno(attrs, inputs, out_type, target): """conv2d adreno strategy""" @@ -31,7 +32,9 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): groups = attrs.groups data_layout = attrs.data_layout kernel_layout = attrs.kernel_layout - assert out_type.dtype == "float16", "No float32 input/output tensor support is currently provided for Adreno GPU" + assert ( + out_type.dtype == "float16" + ), "No float32 input/output tensor support is currently provided for Adreno GPU" if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") @@ -47,16 +50,22 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.adreno.conv2d_nchwc), wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc), name="conv2d_nchwc.image2d", - plevel=10 + plevel=10, ) strategy.add_implementation( wrap_compute_conv2d(topi.adreno.conv2d_nchwc_acc32), wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc_acc32), name="conv2d_nchwc_acc32.image2d", - plevel=20 + plevel=20, ) else: - raise RuntimeError("Layout not supported: ("+data_layout+", "+kernel_layout+") - only support NCHW4c / OIHW4o layouts for conv2d") + raise RuntimeError( + "Layout not supported: (" + + data_layout + + ", " + + kernel_layout + + ") - only support NCHW4c / OIHW4o layouts for conv2d" + ) elif is_depthwise_conv2d(data.shape, data_layout, kernel.shape, kernel_layout, groups): if data_layout == "NCHW" and kernel_layout == "OIHW": strategy.add_implementation( @@ -69,17 +78,22 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc), wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc), name="depthwise_conv2d_nchwc.image2d", - plevel=10 + plevel=10, ) strategy.add_implementation( wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc_acc32), wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc_acc32), 
name="depthwise_conv2d_nchwc_acc32.image2d", - plevel=20 + plevel=20, ) else: - raise RuntimeError("Layout not supported: ("+data_layout+", "+kernel_layout+") - only support NCHW4c / OIHW4o layouts for conv2d") + raise RuntimeError( + "Layout not supported: (" + + data_layout + + ", " + + kernel_layout + + ") - only support NCHW4c / OIHW4o layouts for conv2d" + ) else: raise RuntimeError("General group convolution is not currently supported") return strategy - diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index d67c0898627b..ba96a2a40619 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -263,6 +263,7 @@ def mali(model="unknown", options=None): opts = _merge_opts(opts, options) return Target(" ".join(["opencl"] + opts)) + def adreno(model="unknown", options=None): """Returns a Qualcomm GPU target. @@ -277,6 +278,7 @@ def adreno(model="unknown", options=None): opts = _merge_opts(opts, options) return Target(" ".join(["opencl"] + opts)) + def intel_graphics(model="unknown", options=None): """Returns an Intel Graphics target. diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py index 72b378b9575d..3f457085f6e2 100644 --- a/python/tvm/topi/adreno/conv2d.py +++ b/python/tvm/topi/adreno/conv2d.py @@ -28,55 +28,75 @@ @autotvm.register_topi_compute("conv2d_nchwc.image2d") def conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): """Compute conv2d with NCHWc layout""" - args={"shared" : False, "accumulator" : "float16"} - return compute_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) + args = {"shared": False, "accumulator": "float16"} + return compute_conv2d_NCHWc_KCRSk( + data, kernel, strides, padding, dilation, out_dtype, args=args + ) + @autotvm.register_topi_compute("conv2d_nchwc_acc32.image2d") def conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): """Compute conv2d with NCHWc layout""" - args={"shared" : False, "accumulator" : "float32"} - return compute_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) + args = {"shared": False, "accumulator": "float32"} + return compute_conv2d_NCHWc_KCRSk( + data, kernel, strides, padding, dilation, out_dtype, args=args + ) + @autotvm.register_topi_schedule("conv2d_nchwc.image2d") def schedule_conv2d_nchwc(cfg, outs): return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc16") + @autotvm.register_topi_schedule("conv2d_nchwc_acc32.image2d") def schedule_conv2d_nchwc_acc32(cfg, outs): return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc32") + @autotvm.register_topi_compute("depthwise_conv2d_nchwc.image2d") def depthwise_conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): """Compute depthwise_conv2d with NCHWc layout""" - args={"shared" : False, "accumulator" : "float16"} - return compute_depthwise_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) + args = {"shared": False, "accumulator": "float16"} + return compute_depthwise_conv2d_NCHWc_KCRSk( + data, kernel, strides, padding, dilation, out_dtype, args=args + ) + @autotvm.register_topi_compute("depthwise_conv2d_nchwc_acc32.image2d") -def depthwise_conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): +def depthwise_conv2d_nchwc_acc32( + cfg, data, kernel, strides, padding, dilation, out_dtype="float16" +): """Compute depthwise_conv2d with NCHWc layout""" - args={"shared" : False, 
"accumulator" : "float32"} - return compute_depthwise_conv2d_NCHWc_KCRSk(data, kernel, strides, padding, dilation, out_dtype, args=args) + args = {"shared": False, "accumulator": "float32"} + return compute_depthwise_conv2d_NCHWc_KCRSk( + data, kernel, strides, padding, dilation, out_dtype, args=args + ) + @autotvm.register_topi_schedule("depthwise_conv2d_nchwc.image2d") def schedule_depthwise_conv2d_nchwc(cfg, outs): return schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc16") + @autotvm.register_topi_schedule("depthwise_conv2d_nchwc_acc32.image2d") def schedule_depthwise_conv2d_nchwc_acc32(cfg, outs): return schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc32") + def schedule_conv2d_nchwc_impl(cfg, outs, tag): """Create the schedule for conv2d_nchw""" outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) + def _callback(op): if op.tag == tag: - args={"shared" : False} + args = {"shared": False} schedule_conv2d_NCHWc_KCRSk(cfg, s, op.output(0), args) traverse_inline(s, outs[0].op, _callback) return s + def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None, args={}): """Convolution operator in NCHWc layout. """ @@ -128,14 +148,26 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dty # texture: K|CRS|k Filter_tx = te.compute( (num_filter_chunk, channel * kernel_h * kernel_w, num_filter_block), - lambda ffc, crs, ffb: Filter[ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb], - name = "packed_filter" + lambda ffc, crs, ffb: Filter[ + ffc, crs // (kernel_h * kernel_w), (crs // kernel_w) % kernel_h, crs % kernel_w, ffb + ], + name="packed_filter", ) conv = te.compute( (batch, num_filter_chunk, out_height, out_width, num_filter_block), lambda nn, ffc, yy, xx, ffb: te.sum( - (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] - * Filter_tx[ffc, ((rcc * in_channel_block + rcb)*kernel_h + ry)*kernel_w + rx, ffb]).astype(args["accumulator"]), + ( + temp[ + nn, + rcc, + yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, + rcb, + ] + * Filter_tx[ + ffc, ((rcc * in_channel_block + rcb) * kernel_h + ry) * kernel_w + rx, ffb + ] + ).astype(args["accumulator"]), axis=[rcc, rcb, ry, rx], ), tag="conv2d_nchwc", @@ -144,13 +176,26 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dty conv = te.compute( (batch, num_filter_chunk, out_height, out_width, num_filter_block), lambda nn, ffc, yy, xx, ffb: te.sum( - (temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb] - * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb]).astype(args["accumulator"]), + ( + temp[ + nn, + rcc, + yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, + rcb, + ] + * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb] + ).astype(args["accumulator"]), axis=[rcc, rcb, ry, rx], ), tag="conv2d_nchwc", ) - return te.compute(conv.shape, lambda n,fc,y,x,fb: conv[n,fc,y,x,fb].astype("float16"), tag="cast_from_acc" + args["accumulator"][-2:]) + return te.compute( + conv.shape, + lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype("float16"), + tag="cast_from_acc" + args["accumulator"][-2:], + ) + def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): """schedule optimized for batch size = 1""" @@ -209,6 +254,7 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): if autotvm.GLOBAL_SCOPE.in_tuning: AT = 
s.cache_read(pad_data, "global.texture", [OL]) WT = s.cache_read(kernel, "global.texture-weight", [OL]) + def copy_to_texture(stage): axes = s[stage].op.axis fused = s[stage].fuse(*axes[:-1]) @@ -216,6 +262,7 @@ def copy_to_texture(stage): s[stage].vectorize(axes[-1]) s[stage].bind(block, te.thread_axis("blockIdx.x")) s[stage].bind(thread, te.thread_axis("threadIdx.x")) + copy_to_texture(AT) copy_to_texture(WT) @@ -290,8 +337,7 @@ def copy_to_texture(stage): _, ICKHKW, _ = get_const_tuple(kernel.shape) else: _, IC, KH, KW, _ = get_const_tuple(kernel.shape) - ICKHKW = IC*KH*KW - + ICKHKW = IC * KH * KW if isinstance(N, int): cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) @@ -301,15 +347,19 @@ def schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag): """Create the schedule for depthwise conv2d_nchw4c_ohwi4o""" outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) + def _callback(op): if op.tag == tag: - args={"shared" : False} + args = {"shared": False} schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, op.output(0), args) traverse_inline(s, outs[0].op, _callback) return s -def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None, args={}): + +def compute_depthwise_conv2d_NCHWc_KCRSk( + Input, Filter, stride, padding, dilation, out_dtype=None, args={} +): """Depthwise convolution operator in NCHWc layout. """ if out_dtype is None: out_dtype = Input.dtype @@ -346,7 +396,6 @@ def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilatio ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") - if autotvm.GLOBAL_SCOPE.in_tuning: # NCHWc x CMRSc = [N,(C//4)M,OH,OW, 4c] # NCHWc x CMRS @@ -354,15 +403,29 @@ def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilatio # texture: C|MRS|c Filter_tx = te.compute( (channel_chunk, channel_multiplier * kernel_h * kernel_w, channel_block), - lambda ffc, mrs, ffb: Filter[ffc, mrs // (kernel_h * kernel_w), (mrs // kernel_w) % kernel_h, mrs % kernel_w, ffb], - name = "packed_filter" + lambda ffc, mrs, ffb: Filter[ + ffc, mrs // (kernel_h * kernel_w), (mrs // kernel_w) % kernel_h, mrs % kernel_w, ffb + ], + name="packed_filter", ) conv = te.compute( (batch, out_channel_chunk, out_height, out_width, channel_block), lambda nn, ffc, yy, xx, ffb: te.sum( - (temp[nn, ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] - * Filter_tx[ffc//channel_multiplier, ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx, ffb]).astype(args["accumulator"]), + ( + temp[ + nn, + ffc // channel_multiplier, + yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, + ffb, + ] + * Filter_tx[ + ffc // channel_multiplier, + ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx, + ffb, + ] + ).astype(args["accumulator"]), axis=[ry, rx], ), tag="depthwise_conv2d_nchwc_kcrsk_texture", @@ -371,13 +434,26 @@ def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilatio conv = te.compute( (batch, out_channel_chunk, out_height, out_width, channel_block), lambda nn, ffc, yy, xx, ffb: te.sum( - (temp[nn, ffc//channel_multiplier, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffb] - * Filter[ffc//channel_multiplier, ffc % channel_multiplier, ry, rx, ffb]).astype(args["accumulator"]), + ( + temp[ + nn, + ffc // channel_multiplier, + yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, + ffb, + ] + * Filter[ffc 
// channel_multiplier, ffc % channel_multiplier, ry, rx, ffb] + ).astype(args["accumulator"]), axis=[ry, rx], ), tag="depthwise_conv2d_nchwc_kcrsk", ) - return te.compute(conv.shape, lambda n,ffc,y,x,ffb: conv[n,ffc,y,x,ffb].astype("float16"), tag="cast_from_acc" + args["accumulator"][-2:]) + return te.compute( + conv.shape, + lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype("float16"), + tag="cast_from_acc" + args["accumulator"][-2:], + ) + def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): """schedule optimized for batch size = 1""" @@ -435,6 +511,7 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): if autotvm.GLOBAL_SCOPE.in_tuning: AT = s.cache_read(pad_data, "global.texture", [OL]) WT = s.cache_read(kernel, "global.texture-weight", [OL]) + def copy_to_texture(stage): axes = s[stage].op.axis fused = s[stage].fuse(*axes[:-1]) @@ -442,6 +519,7 @@ def copy_to_texture(stage): s[stage].vectorize(axes[-1]) s[stage].bind(block, te.thread_axis("blockIdx.x")) s[stage].bind(thread, te.thread_axis("threadIdx.x")) + copy_to_texture(AT) copy_to_texture(WT) @@ -485,7 +563,7 @@ def copy_to_texture(stage): s[OL].reorder(ryo, rxo, ryi, rxi, n, fc, y, x, fb) s[OL].vectorize(fb) - #s[OL].unroll() + # s[OL].unroll() if args["shared"]: s[AA].compute_at(s[OL], rxo) @@ -519,7 +597,7 @@ def copy_to_texture(stage): KHKW = MKHKW // M else: ICC, M, KH, KW, ICB = get_const_tuple(kernel.shape) - KHKW = KH*KW + KHKW = KH * KW if isinstance(N, int): cfg.add_flop(2 * N * OH * OW * OCC * OCB * KHKW) From 83a73f17af93710685dcdb14eb1fedf72ac5cd55 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Thu, 22 Jul 2021 10:44:15 -0700 Subject: [PATCH 13/15] Clean up keyword args in conv2d schedules. --- python/tvm/topi/adreno/conv2d.py | 53 +++++++++++++------------------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py index 3f457085f6e2..461b3670aa5e 100644 --- a/python/tvm/topi/adreno/conv2d.py +++ b/python/tvm/topi/adreno/conv2d.py @@ -28,18 +28,16 @@ @autotvm.register_topi_compute("conv2d_nchwc.image2d") def conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): """Compute conv2d with NCHWc layout""" - args = {"shared": False, "accumulator": "float16"} return compute_conv2d_NCHWc_KCRSk( - data, kernel, strides, padding, dilation, out_dtype, args=args + data, kernel, strides, padding, dilation, out_dtype, shared=False, accumulator="float16" ) @autotvm.register_topi_compute("conv2d_nchwc_acc32.image2d") def conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): """Compute conv2d with NCHWc layout""" - args = {"shared": False, "accumulator": "float32"} return compute_conv2d_NCHWc_KCRSk( - data, kernel, strides, padding, dilation, out_dtype, args=args + data, kernel, strides, padding, dilation, out_dtype, shared=False, accumulator="float32" ) @@ -56,9 +54,8 @@ def schedule_conv2d_nchwc_acc32(cfg, outs): @autotvm.register_topi_compute("depthwise_conv2d_nchwc.image2d") def depthwise_conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"): """Compute depthwise_conv2d with NCHWc layout""" - args = {"shared": False, "accumulator": "float16"} return compute_depthwise_conv2d_NCHWc_KCRSk( - data, kernel, strides, padding, dilation, out_dtype, args=args + data, kernel, strides, padding, dilation, out_dtype, shared=False, accumulator="float16" ) @@ -67,9 +64,8 @@ def depthwise_conv2d_nchwc_acc32( cfg, data, kernel, strides, 
padding, dilation, out_dtype="float16" ): """Compute depthwise_conv2d with NCHWc layout""" - args = {"shared": False, "accumulator": "float32"} return compute_depthwise_conv2d_NCHWc_KCRSk( - data, kernel, strides, padding, dilation, out_dtype, args=args + data, kernel, strides, padding, dilation, out_dtype, shared=False, accumulator="float32" ) @@ -90,14 +86,13 @@ def schedule_conv2d_nchwc_impl(cfg, outs, tag): def _callback(op): if op.tag == tag: - args = {"shared": False} - schedule_conv2d_NCHWc_KCRSk(cfg, s, op.output(0), args) + schedule_conv2d_NCHWc_KCRSk(cfg, s, op.output(0), shared=False) traverse_inline(s, outs[0].op, _callback) return s -def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None, args={}): +def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype=None, **kwargs): """Convolution operator in NCHWc layout. """ if out_dtype is None: @@ -167,7 +162,7 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dty * Filter_tx[ ffc, ((rcc * in_channel_block + rcb) * kernel_h + ry) * kernel_w + rx, ffb ] - ).astype(args["accumulator"]), + ).astype(kwargs["accumulator"]), axis=[rcc, rcb, ry, rx], ), tag="conv2d_nchwc", @@ -185,7 +180,7 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dty rcb, ] * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb] - ).astype(args["accumulator"]), + ).astype(kwargs["accumulator"]), axis=[rcc, rcb, ry, rx], ), tag="conv2d_nchwc", @@ -193,11 +188,11 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dty return te.compute( conv.shape, lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype("float16"), - tag="cast_from_acc" + args["accumulator"][-2:], + tag="cast_from_acc" + kwargs["accumulator"][-2:], ) -def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): +def schedule_conv2d_NCHWc_KCRSk(cfg, s, output, **kwargs): """schedule optimized for batch size = 1""" conv = output.op.input_tensors[0] @@ -266,10 +261,10 @@ def copy_to_texture(stage): copy_to_texture(AT) copy_to_texture(WT) - if args["shared"]: + if kwargs["shared"]: AA = s.cache_read(AT, "shared", [OL]) WW = s.cache_read(WT, "shared", [OL]) - elif args["shared"]: + elif kwargs["shared"]: AA = s.cache_read(pad_data, "shared", [OL]) WW = s.cache_read(kernel, "shared", [OL]) @@ -309,7 +304,7 @@ def copy_to_texture(stage): s[OL].vectorize(fb) s[OL].unroll(rcb) - if args["shared"]: + if kwargs["shared"]: s[AA].compute_at(s[OL], rxo) s[WW].compute_at(s[OL], rxo) # cooperative fetching @@ -350,15 +345,14 @@ def schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag): def _callback(op): if op.tag == tag: - args = {"shared": False} - schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, op.output(0), args) + schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, op.output(0), shared=False) traverse_inline(s, outs[0].op, _callback) return s def compute_depthwise_conv2d_NCHWc_KCRSk( - Input, Filter, stride, padding, dilation, out_dtype=None, args={} + Input, Filter, stride, padding, dilation, out_dtype=None, **kwargs ): """Depthwise convolution operator in NCHWc layout. 
""" if out_dtype is None: @@ -425,7 +419,7 @@ def compute_depthwise_conv2d_NCHWc_KCRSk( ((ffc % channel_multiplier) * kernel_h + ry) * kernel_w + rx, ffb, ] - ).astype(args["accumulator"]), + ).astype(kwargs["accumulator"]), axis=[ry, rx], ), tag="depthwise_conv2d_nchwc_kcrsk_texture", @@ -443,7 +437,7 @@ def compute_depthwise_conv2d_NCHWc_KCRSk( ffb, ] * Filter[ffc // channel_multiplier, ffc % channel_multiplier, ry, rx, ffb] - ).astype(args["accumulator"]), + ).astype(kwargs["accumulator"]), axis=[ry, rx], ), tag="depthwise_conv2d_nchwc_kcrsk", @@ -451,11 +445,11 @@ def compute_depthwise_conv2d_NCHWc_KCRSk( return te.compute( conv.shape, lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype("float16"), - tag="cast_from_acc" + args["accumulator"][-2:], + tag="cast_from_acc" + kwargs["accumulator"][-2:], ) -def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, args={}): +def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output, **kwargs): """schedule optimized for batch size = 1""" conv = output.op.input_tensors[0] @@ -523,10 +517,10 @@ def copy_to_texture(stage): copy_to_texture(AT) copy_to_texture(WT) - if args["shared"]: + if kwargs["shared"]: AA = s.cache_read(AT, "shared", [OL]) WW = s.cache_read(WT, "shared", [OL]) - elif args["shared"]: + elif kwargs["shared"]: AA = s.cache_read(pad_data, "shared", [OL]) WW = s.cache_read(kernel, "shared", [OL]) @@ -563,9 +557,8 @@ def copy_to_texture(stage): s[OL].reorder(ryo, rxo, ryi, rxi, n, fc, y, x, fb) s[OL].vectorize(fb) - # s[OL].unroll() - if args["shared"]: + if kwargs["shared"]: s[AA].compute_at(s[OL], rxo) s[WW].compute_at(s[OL], rxo) # cooperative fetching @@ -589,8 +582,6 @@ def copy_to_texture(stage): s[output].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) N, OCC, OH, OW, OCB = get_const_tuple(output.shape) - # OC = OCC * OCB = IC * M - # M = OC // IC == (OCC * OCB) // ICC * ICB if autotvm.GLOBAL_SCOPE.in_tuning: ICC, MKHKW, ICB = get_const_tuple(kernel.shape) M = (OCC * OCB) // (ICC * ICB) From 257fe455e57ed2f30eb937c979fcc535ae3f47e6 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Thu, 22 Jul 2021 10:46:41 -0700 Subject: [PATCH 14/15] Lint --- python/tvm/relay/op/strategy/adreno.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index b65cfe0b014f..dc3b7cc95fff 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -16,7 +16,6 @@ # under the License. """Definition of adreno operator strategy.""" # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import -import re from tvm import topi from .generic import * from .. import op as _op @@ -28,7 +27,6 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): strategy = _op.OpStrategy() data, kernel = inputs dilation_h, dilation_w = attrs.get_int_tuple("dilation") - stride_h, stride_w = attrs.get_int_tuple("strides") groups = attrs.groups data_layout = attrs.data_layout kernel_layout = attrs.kernel_layout From c4b03f346bab94711a3ed7cbe1d0642283eb4af8 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Thu, 22 Jul 2021 10:52:14 -0700 Subject: [PATCH 15/15] Loosen fp16 restrictions. 
--- python/tvm/relay/op/strategy/adreno.py | 3 --- python/tvm/topi/adreno/conv2d.py | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index dc3b7cc95fff..fe375d760f96 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -30,9 +30,6 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): groups = attrs.groups data_layout = attrs.data_layout kernel_layout = attrs.kernel_layout - assert ( - out_type.dtype == "float16" - ), "No float32 input/output tensor support is currently provided for Adreno GPU" if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") diff --git a/python/tvm/topi/adreno/conv2d.py b/python/tvm/topi/adreno/conv2d.py index 461b3670aa5e..cfffec377bcd 100644 --- a/python/tvm/topi/adreno/conv2d.py +++ b/python/tvm/topi/adreno/conv2d.py @@ -187,7 +187,7 @@ def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dty ) return te.compute( conv.shape, - lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype("float16"), + lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype(out_dtype), tag="cast_from_acc" + kwargs["accumulator"][-2:], ) @@ -444,7 +444,7 @@ def compute_depthwise_conv2d_NCHWc_KCRSk( ) return te.compute( conv.shape, - lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype("float16"), + lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype(out_dtype), tag="cast_from_acc" + kwargs["accumulator"][-2:], )
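For completeness, a short sketch of how the series is exercised from relay. The module below is a stand-in for any network whose conv2d layers have been converted to the NCHW4c / OIHW4o layouts this strategy accepts; all shapes are illustrative:

    import tvm
    from tvm import autotvm, relay

    # Composes "opencl -device=adreno -model=unknown", equivalent to the
    # plain target string form.
    target = tvm.target.adreno()

    data = relay.var("data", shape=(1, 8, 32, 32, 4), dtype="float16")
    weight = relay.var("weight", shape=(8, 32, 3, 3, 4), dtype="float16")
    out = relay.nn.conv2d(
        data,
        weight,
        channels=32,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW4c",
        kernel_layout="OIHW4o",
    )
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

    # Extraction yields one tunable task per registered compute (plevel=10
    # for the float16 accumulator, plevel=20 for float32); while measuring,
    # GLOBAL_SCOPE.in_tuning routes the inputs through the global.texture
    # cache_read stages shown in the schedules above.
    tasks = autotvm.task.extract_from_program(mod["main"], params={}, target=target)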