From 59c2487fb45f361ca2b0711ee8a546d56c588f01 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Wed, 20 Sep 2023 11:14:48 +0530 Subject: [PATCH 01/18] [TOPI][ADRENO] Add conv2d transpose nchw texture schedule Added the conv2d transpose strategy for adreno target and enable the optimized schedule. --- python/tvm/relay/op/strategy/adreno.py | 54 +++ python/tvm/topi/adreno/__init__.py | 1 + .../tvm/topi/adreno/conv2d_transpose_nchw.py | 390 ++++++++++++++++++ python/tvm/topi/adreno/utils.py | 23 ++ .../test_conv2d_transpose_nchw_texture.py | 73 ++++ .../opencl_texture/utils/adreno_utils.py | 5 +- 6 files changed, 545 insertions(+), 1 deletion(-) create mode 100644 python/tvm/topi/adreno/conv2d_transpose_nchw.py create mode 100644 tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index c180eeec7414..7f8e8e0875ba 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -214,6 +214,60 @@ def conv2d_winograd_without_weight_transform_strategy_adreno(attrs, inputs, out_ raise RuntimeError(f"Unsupported conv2d_winograd_without_weight_transform layout {layout}") return strategy +@conv2d_transpose_strategy.register("adreno") +def conv2d_transpose_strategy_adreno(attrs, inputs, out_type, target): + """conv2d_transpose adreno strategy""" + strategy = _op.OpStrategy() + data, kernel = inputs + dilation = attrs.get_int_tuple("dilation") + stride_h, stride_w = attrs.get_int_tuple("strides") + groups = attrs.groups + data_layout = attrs.data_layout + kernel_layout = attrs.kernel_layout + assert dilation == (1, 1), "not support dilate now" + + if ( + (groups == 1) and ( + (data_layout == "NCHW" and kernel_layout == "IOHW") + or (data_layout == "NCHW4c" and kernel_layout == "IOHW4o") + or (data_layout == "NCHW" and kernel_layout == "IOHW4o") + ) + ): + if len(kernel.shape) == 4: + oc, _, kh, kw = get_const_tuple(kernel.shape) + 
else: + oc, _, kh, kw, _ = get_const_tuple(kernel.shape) + # We cannot use textures for case than number of channels is less than 4. + # So, we use compute functions from cuda. + if len(kernel.shape) == 4 and oc < 4: + strategy.add_implementation( + wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_nchw), + wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw), + name="conv2d_transpose_nchw.cuda", + ) + return strategy + strategy.add_implementation( + wrap_compute_conv2d_transpose(topi.adreno.conv2d_transpose_nchwc), + wrap_topi_schedule(topi.adreno.schedule_conv2d_transpose_nchwc), + name="conv2d_transpose_nchwc.image2d", + plevel=10, + ) + elif data_layout == "NCHW": + strategy.add_implementation( + wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_nchw, has_groups=True), + wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw), + name="conv2d_transpose_nchw.cuda", + ) + else: + raise RuntimeError( + "Layout not supported: (" + + data_layout + + ", " + + kernel_layout + + ") - only support NCHW, NCHW4c / IOHW4o layouts for conv2d_transpose" + ) + return strategy + @schedule_pool.register("adreno") def schedule_pool_adreno(attrs, outs, target): diff --git a/python/tvm/topi/adreno/__init__.py b/python/tvm/topi/adreno/__init__.py index 55bfbee2a8d7..f54c6f9bd6fe 100644 --- a/python/tvm/topi/adreno/__init__.py +++ b/python/tvm/topi/adreno/__init__.py @@ -27,3 +27,4 @@ from .conv2d_nhwc_winograd import * from .injective import schedule_injective from .reduction import * +from .conv2d_transpose_nchw import * diff --git a/python/tvm/topi/adreno/conv2d_transpose_nchw.py b/python/tvm/topi/adreno/conv2d_transpose_nchw.py new file mode 100644 index 000000000000..50ea11161b12 --- /dev/null +++ b/python/tvm/topi/adreno/conv2d_transpose_nchw.py @@ -0,0 +1,390 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return +"""conv2d_transpose nchw schedule on Qualcomm Adreno GPU""" +import tvm +from tvm import te +from tvm import autotvm +from .. import nn +from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity + +from ..utils import get_const_tuple, traverse_inline +from .utils import ( + split_to_chunks, + pack_input, + pack_filter, + expand_spatial_dimensions, + add_pad, + bind_data_copy, + get_default_conv2d_config, + get_texture_storage, +) + + +@autotvm.register_topi_compute("conv2d_transpose_nchwc.image2d") +def conv2d_transpose_nchwc(cfg, Input, Filter, stride, padding, out_dtype, output_padding, groups=1): + """ + Transposed Convolution operator in NCHWc layout. 
+ Algo: + """ + + if out_dtype is None: + out_dtype = Input.dtype + assert isinstance(stride, int) or len(stride) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + outpad_height, outpad_width = output_padding + assert outpad_height < stride_h and outpad_width < stride_w + + convert_from4d = False + if len(Input.shape) == 4: + batch, in_channels, in_height, in_width = Input.shape + in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(in_channels, 4) + + if autotvm.GLOBAL_SCOPE.in_tuning: + dshape = (batch, in_channel_chunks, in_height, in_width, in_channel_block) + Input = tvm.te.placeholder(dshape, Input.dtype, name="data_placeholder") + else: + Input = pack_input( + Input, + "NCHW", + batch, + in_channel_chunks, + in_channel_block, + in_channel_tail, + in_height, + in_width, + ) + else: + batch, in_channel_chunks, in_height, in_width, in_channel_block = Input.shape + + if len(Filter.shape) == 4: + in_filter_channels, out_channels, kernel_h, kernel_w = Filter.shape + out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channels, 4) + + if autotvm.GLOBAL_SCOPE.in_tuning: + kshape = (in_filter_channels, out_channel_chunks, kernel_h, kernel_w, out_channel_block) + Filter = tvm.te.placeholder(kshape, Filter.dtype, name="kernel_placeholder") + else: + convert_from4d = True + Filter = pack_filter( + Filter, + "IOHW", + out_channel_chunks, + out_channel_block, + out_channel_tail, + in_filter_channels, + in_channel_chunks, + in_channel_block, + in_channel_tail, + kernel_h, + kernel_w, + ) + else: + in_filter_channels, out_channel_chunks, kernel_h, kernel_w, out_channel_block = Filter.shape + + """ + assert ( + in_channels % groups == 0 + ), f"input channels {inp_channels} must divide group size {groups}": + """ + + cfg.stride = stride + + pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple( + padding, (kernel_h, kernel_w) + ) + + out_width_orig = out_width = 
(in_width - 1) * stride_w + kernel_w - pad_left - pad_right + outpad_width + pad_left = kernel_w - 1 - pad_left + pad_right = kernel_w - 1 - pad_right + outpad_width + dilated_width = stride_w * (in_width - 1) + 1 + + out_height_orig = out_height = ( + (in_height - 1) * stride_h + kernel_h - pad_top - pad_bottom + outpad_height + ) + pad_top = kernel_h - 1 - pad_top + pad_bottom = kernel_h - 1 - pad_bottom + outpad_height + dilated_height = stride_h * (in_height - 1) + 1 + + if out_height % 2 != 0: + out_height += 1 + if out_width % 2 != 0: + out_width += 1 + + if out_height % 4 != 0: + out_height += 2 + if out_width % 4 != 0: + out_width += 2 + + # compute pad + temp = te.compute( + ( + batch, + in_channel_chunks, + pad_top + dilated_height + pad_bottom, + pad_left + dilated_width + pad_right, + in_channel_block, + ), + lambda n, c, y, x, cb: tvm.tir.if_then_else( + tvm.tir.all( + x >= pad_left, + x < pad_left + dilated_width, + tvm.tir.indexmod(x - pad_left, stride_w).equal(0), + y >= pad_top, + y < pad_top + dilated_height, + tvm.tir.indexmod(y - pad_top, stride_h).equal(0), + ), + Input[ + n, + c, + tvm.tir.indexdiv(y - pad_top, stride_h), + tvm.tir.indexdiv(x - pad_left, stride_w), + cb, + ], + tvm.tir.const(0.0, Input.dtype), + ), + name="pad_temp", + ) + + # compute transposed conv + dcc = te.reduce_axis((0, in_channel_chunks), name="dcc") + dcb = te.reduce_axis((0, in_channel_block), name="dcb") + dh = te.reduce_axis((0, kernel_h), name="dh") + dw = te.reduce_axis((0, kernel_w), name="dw") + conv = te.compute( + (batch, out_channel_chunks, out_height, out_width, out_channel_block), + lambda b, c, h, w, cb: te.sum( + temp[b, c // out_channel_chunks * (in_channel_chunks) + dcc, h + dh, w + dw, dcb].astype( + out_dtype + ) + * Filter[ + dcc * in_channel_block + dcb, + c % out_channel_chunks, + kernel_h - 1 - dh, + kernel_w - 1 - dw, + cb, + ].astype(out_dtype), + axis=[dcc, dcb, dh, dw], + ), + tag="conv2d_transpose_nchwc", + ) + + + if convert_from4d and not 
autotvm.GLOBAL_SCOPE.in_tuning: + dummy_cast = te.compute( + (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block), + lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype(out_dtype), + tag="dummy_cast", + ) + return te.compute( + (batch, out_channels, out_height_orig, out_width_orig), + lambda n, c, y, x: dummy_cast[n, c // out_channel_block, y, x, c % out_channel_block], + tag="adreno_conv2d_transpose_latest_op", + ) + else: + return te.compute( + (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block), + lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype(out_dtype), + tag="adreno_conv2d_transpose_latest_op", + ) + +@autotvm.register_topi_schedule("conv2d_transpose_nchwc.image2d") +def schedule_conv2d_transpose_nchwc(cfg, outs): + """Create the schedule for conv2d_nchw""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == "adreno_conv2d_transpose_latest_op": + schedule_conv2d_transpose_NCHWc(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s + +def schedule_conv2d_transpose_NCHWc(cfg, s, output): + """ + schedule optimized for batch size = 1 + + Algo: + """ + latest = s.outputs[0].output(0) + if len(latest.op.axis) == 4: + latest_blocked = dummy = output.op.input_tensors[0] + conv = dummy.op.input_tensors[0] + else: + conv = output.op.input_tensors[0] + latest_blocked = latest + + pad_data, kernel = s[conv].op.input_tensors + filter_pack_rt = bool( + isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag + ) + + if "pad_temp" in pad_data.op.name: + input_pad_temp = pad_data.op.input_tensors[0] + else: + input_pad_temp = pad_data + + input_pack_rt = bool( + isinstance(input_pad_temp.op, tvm.te.ComputeOp) and "input_pack" in input_pad_temp.op.tag + ) + + ##### space definition begin ##### + n, fc, y, x, fb = s[conv].op.axis + rcc, rcb, ry, rx = s[conv].op.reduce_axis 
+ + if conv.shape[1] % 2 == 0: + min_threads_div = 2 + else: + min_threads_div = 1 + cfg.define_split( + "tile_fc", + fc, + num_outputs=3, + filter=lambda entity: entity.size[1] <= 8 + and entity.size[2] >= min_threads_div + and entity.size[2] < 256, + ) + cfg.define_split( + "tile_y", + y, + num_outputs=3, + filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16, + ) + cfg.define_split( + "tile_x", + x, + num_outputs=3, + filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16, + ) + + cfg.define_split("tile_rcc", rcc, num_outputs=2) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 64]) + cfg.define_knob("unroll_explicit", [0, 1]) + cfg.multi_filter( + filter=lambda entity: ( # pylint: disable=chained-comparison + entity["tile_fc"].size[1] * entity["tile_y"].size[1] * entity["tile_x"].size[1] + ) + <= 24 + and 32 + <= (entity["tile_fc"].size[2] * entity["tile_y"].size[2] * entity["tile_x"].size[2]) + < 1024 + ) + if cfg.is_fallback: + get_default_conv2d_config(cfg, conv.shape[1], conv.shape[2], conv.shape[3]) + ##### space definition end ##### + + pad_data, kernel = s[conv].op.input_tensors + # There are several conditions that have to be handled: + # 1. If we are in the tuning, we always add cache read for data to main conv kernel + # to get texture in tuning opencl kernel + # 2. If we are repacking input in runtime, we should always explicit schedule this one more + # stage of data copy from 4d to 5d (referred as pack_data). + # 3. 
If we have pad (independently if we have runtime repack or not) we should inline it in the + # cache_read("texture") + if autotvm.GLOBAL_SCOPE.in_tuning or input_pack_rt: + if autotvm.GLOBAL_SCOPE.in_tuning: + if "pad_temp" in pad_data.op.name: + s[pad_data].compute_inline() + else: + if "pad_temp" in pad_data.op.name: + pack_data = pad_data.op.input_tensors[0] + bind_data_copy(s[pack_data]) + s[pad_data].compute_inline() + else: + pack_data = pad_data + bind_data_copy(s[pack_data]) + + AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv]) + bind_data_copy(s[AT]) + elif "pad_temp" in pad_data.op.name: + s[pad_data].compute_inline() + # create cache stage + AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv]) + bind_data_copy(s[AT]) + + if autotvm.GLOBAL_SCOPE.in_tuning or filter_pack_rt: + if not autotvm.GLOBAL_SCOPE.in_tuning: + bind_data_copy(s[kernel]) + if kernel.shape[2] == 1 and kernel.shape[3] == 1: + WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv]) + bind_data_copy(s[WT]) + + s[conv].set_scope("local") + if latest_blocked == latest and output != latest: + s[output].compute_inline() + + # tile and bind spatial axes + n, fc, y, x, fb = s[latest_blocked].op.axis + + kernel_scope, n = s[latest_blocked].split(n, nparts=1) + + bf, vf, tf = cfg["tile_fc"].apply(s, latest_blocked, fc) + by, vy, ty = cfg["tile_y"].apply(s, latest_blocked, y) + bx, vx, tx = cfg["tile_x"].apply(s, latest_blocked, x) + + bf = s[latest_blocked].fuse(n, bf) + s[latest_blocked].bind(bf, te.thread_axis("blockIdx.z")) + s[latest_blocked].bind(by, te.thread_axis("blockIdx.y")) + s[latest_blocked].bind(bx, te.thread_axis("blockIdx.x")) + s[latest_blocked].bind(vf, te.thread_axis("vthread")) + s[latest_blocked].bind(vy, te.thread_axis("vthread")) + s[latest_blocked].bind(vx, te.thread_axis("vthread")) + s[latest_blocked].bind(tf, te.thread_axis("threadIdx.z")) + s[latest_blocked].bind(ty, te.thread_axis("threadIdx.y")) + 
s[latest_blocked].bind(tx, te.thread_axis("threadIdx.x")) + s[latest_blocked].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fb) + s[latest_blocked].vectorize(fb) + + s[conv].compute_at(s[latest_blocked], tx) + + # tile reduction axes + n, fc, y, x, fb = s[conv].op.axis + + rcc, rcb, ry, rx = s[conv].op.reduce_axis + rco, rci = cfg["tile_rcc"].apply(s, conv, rcc) + ryo, ryi = cfg["tile_ry"].apply(s, conv, ry) + rxo, rxi = cfg["tile_rx"].apply(s, conv, rx) + + s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb) + s[conv].vectorize(fb) + s[conv].unroll(rcb) + + # unroll + s[latest_blocked].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + s[latest_blocked].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val) + + if latest_blocked != latest: + s[latest].compute_root() + bind_data_copy(s[latest], 1) + if latest != output: + s[output].compute_inline() + + N, OCC, OH, OW, OCB = get_const_tuple(latest_blocked.shape) + _, IC, KH, KW, _ = get_const_tuple(kernel.shape) + ICKHKW = IC * KH * KW + + if isinstance(N, int): + cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) + + diff --git a/python/tvm/topi/adreno/utils.py b/python/tvm/topi/adreno/utils.py index 698a306514db..a42cbeeb773b 100644 --- a/python/tvm/topi/adreno/utils.py +++ b/python/tvm/topi/adreno/utils.py @@ -281,6 +281,22 @@ def _reorder_weights_hwio(*indices): Filter[indices[0], indices[1], indices[2], indices[3] * out_block + indices[4]], ) + def _reorder_weights_iohw(*indices): + conditionA = [] + conditionA.append(indices[1] == out_chunks - 1) + conditionA.append(indices[4] >= out_original_tail) + conditionAT = tvm.tir.all(*conditionA) + + conditionO = [] + conditionO.append(conditionAT) + conditionO.append(indices[0] >= in_chunks * in_block + in_original_tail) + conditionOT = tvm.tir.any(*conditionO) + return tvm.tir.if_then_else( + conditionOT, + pad_value, + Filter[indices[0], indices[1] * out_block + indices[4], indices[2], indices[3]], + ) + if 
in_filter_channels == 1: if layout == "OIHW": reordered_filter = te.compute( @@ -313,6 +329,13 @@ def _reorder_weights_hwio(*indices): name="filter_pack", tag="filter_pack", ) + elif layout == "IOHW": + reordered_filter = te.compute( + [in_filter_channels, out_chunks, kernel_h, kernel_w, out_block], + _reorder_weights_iohw, + name="filter_pack", + tag="filter_pack", + ) elif layout == "HWIO": reordered_filter = te.compute( [kernel_h, kernel_w, in_filter_channels, out_chunks, out_block], diff --git a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py new file mode 100644 index 000000000000..6f80f5c78a52 --- /dev/null +++ b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import re +import tvm +import numpy as np +from tvm import relay +from tvm.relay import testing +from tvm.contrib import utils +from utils.adreno_utils import gpu_preprocess, build_run_compare, build_run_compare_vm +import pytest + +executor_type = tvm.testing.parameter("ge") +dtype = tvm.testing.parameter("float32") + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): + input_shape = (1, 256, 100, 100) + filter_shape = (256, 64, 4, 4) + channels = 64 + kernel_size = (4, 4) + strides = (2, 2) + padding = (1, 1, 1, 1) + x = relay.var("data", shape=input_shape, dtype=dtype) + w = relay.var("weight", shape=filter_shape, dtype=dtype) + + y = relay.nn.conv2d_transpose( + x, + w, + channels=channels, + kernel_size=kernel_size, + strides=strides, + padding=padding, + kernel_layout="IOHW", + data_layout="NCHW", + ) + + mod = relay.Function([x, w], y) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data = np.zeros(filter_shape).astype(dtype) + initializer("weight", filter_data) + params1 = { + "weight": tvm.nd.array(filter_data), + } + + if executor_type == "ge": + build_run_compare( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess + ) + else: + build_run_compare_vm( + remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/opencl_texture/utils/adreno_utils.py b/tests/python/relay/opencl_texture/utils/adreno_utils.py index d9e52f8847a7..233a93fa66f0 100644 --- a/tests/python/relay/opencl_texture/utils/adreno_utils.py +++ b/tests/python/relay/opencl_texture/utils/adreno_utils.py @@ -200,7 +200,10 @@ def build_run_compare_vm( def gpu_preprocess(tvm_mod): layout_config = relay.transform.LayoutConfig() - desired_layouts = {"nn.conv2d": ["NCHW4c", "OIHW4o"]} + desired_layouts = { + 
"nn.conv2d": ["NCHW4c", "OIHW4o"], + "nn.conv2d_transpose": ["NCHW4c", "IOHW4o"] + } with layout_config: seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)]) with tvm.transform.PassContext(opt_level=3): From 25d446a9a77f49426766a16e0041c289b07989a8 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Wed, 20 Sep 2023 11:38:35 +0530 Subject: [PATCH 02/18] Fix the whitespace lint error --- python/tvm/relay/op/strategy/adreno.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index 7f8e8e0875ba..51dcb9aa0c1a 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -224,7 +224,7 @@ def conv2d_transpose_strategy_adreno(attrs, inputs, out_type, target): groups = attrs.groups data_layout = attrs.data_layout kernel_layout = attrs.kernel_layout - assert dilation == (1, 1), "not support dilate now" + assert dilation == (1, 1), "not support dilate now" if ( (groups == 1) and ( From 31c172922aaf497baea6cd75476edcb5df4248b8 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Wed, 20 Sep 2023 14:54:52 +0530 Subject: [PATCH 03/18] Fix lint errors --- python/tvm/relay/op/strategy/adreno.py | 11 ++++---- .../tvm/topi/adreno/conv2d_transpose_nchw.py | 25 ++++++++++--------- .../test_conv2d_transpose_nchw_texture.py | 1 + .../opencl_texture/utils/adreno_utils.py | 6 ++--- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index 51dcb9aa0c1a..f86aa6923228 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -214,6 +214,7 @@ def conv2d_winograd_without_weight_transform_strategy_adreno(attrs, inputs, out_ raise RuntimeError(f"Unsupported conv2d_winograd_without_weight_transform layout {layout}") return strategy + @conv2d_transpose_strategy.register("adreno") def 
conv2d_transpose_strategy_adreno(attrs, inputs, out_type, target): """conv2d_transpose adreno strategy""" @@ -226,12 +227,10 @@ def conv2d_transpose_strategy_adreno(attrs, inputs, out_type, target): kernel_layout = attrs.kernel_layout assert dilation == (1, 1), "not support dilate now" - if ( - (groups == 1) and ( - (data_layout == "NCHW" and kernel_layout == "IOHW") - or (data_layout == "NCHW4c" and kernel_layout == "IOHW4o") - or (data_layout == "NCHW" and kernel_layout == "IOHW4o") - ) + if (groups == 1) and ( + (data_layout == "NCHW" and kernel_layout == "IOHW") + or (data_layout == "NCHW4c" and kernel_layout == "IOHW4o") + or (data_layout == "NCHW" and kernel_layout == "IOHW4o") ): if len(kernel.shape) == 4: oc, _, kh, kw = get_const_tuple(kernel.shape) diff --git a/python/tvm/topi/adreno/conv2d_transpose_nchw.py b/python/tvm/topi/adreno/conv2d_transpose_nchw.py index 50ea11161b12..e57b83416102 100644 --- a/python/tvm/topi/adreno/conv2d_transpose_nchw.py +++ b/python/tvm/topi/adreno/conv2d_transpose_nchw.py @@ -36,7 +36,9 @@ @autotvm.register_topi_compute("conv2d_transpose_nchwc.image2d") -def conv2d_transpose_nchwc(cfg, Input, Filter, stride, padding, out_dtype, output_padding, groups=1): +def conv2d_transpose_nchwc( + cfg, Input, Filter, stride, padding, out_dtype, output_padding, groups=1 +): """ Transposed Convolution operator in NCHWc layout. 
Algo: @@ -108,11 +110,11 @@ def conv2d_transpose_nchwc(cfg, Input, Filter, stride, padding, out_dtype, outpu cfg.stride = stride - pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple( - padding, (kernel_h, kernel_w) - ) + pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(padding, (kernel_h, kernel_w)) - out_width_orig = out_width = (in_width - 1) * stride_w + kernel_w - pad_left - pad_right + outpad_width + out_width_orig = out_width = ( + (in_width - 1) * stride_w + kernel_w - pad_left - pad_right + outpad_width + ) pad_left = kernel_w - 1 - pad_left pad_right = kernel_w - 1 - pad_right + outpad_width dilated_width = stride_w * (in_width - 1) + 1 @@ -172,9 +174,9 @@ def conv2d_transpose_nchwc(cfg, Input, Filter, stride, padding, out_dtype, outpu conv = te.compute( (batch, out_channel_chunks, out_height, out_width, out_channel_block), lambda b, c, h, w, cb: te.sum( - temp[b, c // out_channel_chunks * (in_channel_chunks) + dcc, h + dh, w + dw, dcb].astype( - out_dtype - ) + temp[ + b, c // out_channel_chunks * (in_channel_chunks) + dcc, h + dh, w + dw, dcb + ].astype(out_dtype ) * Filter[ dcc * in_channel_block + dcb, c % out_channel_chunks, @@ -187,7 +189,6 @@ def conv2d_transpose_nchwc(cfg, Input, Filter, stride, padding, out_dtype, outpu tag="conv2d_transpose_nchwc", ) - if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning: dummy_cast = te.compute( (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block), @@ -206,6 +207,7 @@ def conv2d_transpose_nchwc(cfg, Input, Filter, stride, padding, out_dtype, outpu tag="adreno_conv2d_transpose_latest_op", ) + @autotvm.register_topi_schedule("conv2d_transpose_nchwc.image2d") def schedule_conv2d_transpose_nchwc(cfg, outs): """Create the schedule for conv2d_nchw""" @@ -219,6 +221,7 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + def schedule_conv2d_transpose_NCHWc(cfg, s, output): """ schedule optimized for batch size = 1 @@ -385,6 +388,4 @@ def 
schedule_conv2d_transpose_NCHWc(cfg, s, output): ICKHKW = IC * KH * KW if isinstance(N, int): - cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) - - + cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) \ No newline at end of file diff --git a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py index 6f80f5c78a52..6f2e15942fc0 100644 --- a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py +++ b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py @@ -27,6 +27,7 @@ executor_type = tvm.testing.parameter("ge") dtype = tvm.testing.parameter("float32") + @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): diff --git a/tests/python/relay/opencl_texture/utils/adreno_utils.py b/tests/python/relay/opencl_texture/utils/adreno_utils.py index 233a93fa66f0..76a9835205f9 100644 --- a/tests/python/relay/opencl_texture/utils/adreno_utils.py +++ b/tests/python/relay/opencl_texture/utils/adreno_utils.py @@ -201,9 +201,9 @@ def build_run_compare_vm( def gpu_preprocess(tvm_mod): layout_config = relay.transform.LayoutConfig() desired_layouts = { - "nn.conv2d": ["NCHW4c", "OIHW4o"], - "nn.conv2d_transpose": ["NCHW4c", "IOHW4o"] - } + "nn.conv2d": ["NCHW4c", "OIHW4o"], + "nn.conv2d_transpose": ["NCHW4c", "IOHW4o"] + } with layout_config: seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)]) with tvm.transform.PassContext(opt_level=3): From 2db6a348b6f1d4300597a80c99a454cc3f084b65 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Wed, 20 Sep 2023 15:55:50 +0530 Subject: [PATCH 04/18] Fix whitespace lint error --- python/tvm/topi/adreno/conv2d_transpose_nchw.py | 10 ++-------- .../python/relay/opencl_texture/utils/adreno_utils.py | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git 
a/python/tvm/topi/adreno/conv2d_transpose_nchw.py b/python/tvm/topi/adreno/conv2d_transpose_nchw.py index e57b83416102..5d2bce483c3b 100644 --- a/python/tvm/topi/adreno/conv2d_transpose_nchw.py +++ b/python/tvm/topi/adreno/conv2d_transpose_nchw.py @@ -102,12 +102,6 @@ def conv2d_transpose_nchwc( else: in_filter_channels, out_channel_chunks, kernel_h, kernel_w, out_channel_block = Filter.shape - """ - assert ( - in_channels % groups == 0 - ), f"input channels {inp_channels} must divide group size {groups}": - """ - cfg.stride = stride pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(padding, (kernel_h, kernel_w)) @@ -176,7 +170,7 @@ def conv2d_transpose_nchwc( lambda b, c, h, w, cb: te.sum( temp[ b, c // out_channel_chunks * (in_channel_chunks) + dcc, h + dh, w + dw, dcb - ].astype(out_dtype ) + ].astype(out_dtype) * Filter[ dcc * in_channel_block + dcb, c % out_channel_chunks, @@ -388,4 +382,4 @@ def schedule_conv2d_transpose_NCHWc(cfg, s, output): ICKHKW = IC * KH * KW if isinstance(N, int): - cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) \ No newline at end of file + cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW) diff --git a/tests/python/relay/opencl_texture/utils/adreno_utils.py b/tests/python/relay/opencl_texture/utils/adreno_utils.py index 76a9835205f9..21bdfbdee3cb 100644 --- a/tests/python/relay/opencl_texture/utils/adreno_utils.py +++ b/tests/python/relay/opencl_texture/utils/adreno_utils.py @@ -202,7 +202,7 @@ def gpu_preprocess(tvm_mod): layout_config = relay.transform.LayoutConfig() desired_layouts = { "nn.conv2d": ["NCHW4c", "OIHW4o"], - "nn.conv2d_transpose": ["NCHW4c", "IOHW4o"] + "nn.conv2d_transpose": ["NCHW4c", "IOHW4o"], } with layout_config: seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)]) From 5d37af42319b81780d5ddd21ed5bc2b99cda8aa7 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 25 Sep 2023 11:51:49 +0530 Subject: [PATCH 05/18] Removed unused variables --- 
python/tvm/relay/op/strategy/adreno.py | 7 +++---- python/tvm/topi/adreno/conv2d_transpose_nchw.py | 4 +--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index f86aa6923228..205355b40f1d 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -219,9 +219,8 @@ def conv2d_winograd_without_weight_transform_strategy_adreno(attrs, inputs, out_ def conv2d_transpose_strategy_adreno(attrs, inputs, out_type, target): """conv2d_transpose adreno strategy""" strategy = _op.OpStrategy() - data, kernel = inputs + _, kernel = inputs dilation = attrs.get_int_tuple("dilation") - stride_h, stride_w = attrs.get_int_tuple("strides") groups = attrs.groups data_layout = attrs.data_layout kernel_layout = attrs.kernel_layout @@ -233,9 +232,9 @@ def conv2d_transpose_strategy_adreno(attrs, inputs, out_type, target): or (data_layout == "NCHW" and kernel_layout == "IOHW4o") ): if len(kernel.shape) == 4: - oc, _, kh, kw = get_const_tuple(kernel.shape) + oc, _, _, _ = get_const_tuple(kernel.shape) else: - oc, _, kh, kw, _ = get_const_tuple(kernel.shape) + oc, _, _, _, _ = get_const_tuple(kernel.shape) # We cannot use textures for case than number of channels is less than 4. # So, we use compute functions from cuda. if len(kernel.shape) == 4 and oc < 4: diff --git a/python/tvm/topi/adreno/conv2d_transpose_nchw.py b/python/tvm/topi/adreno/conv2d_transpose_nchw.py index 5d2bce483c3b..75abca761489 100644 --- a/python/tvm/topi/adreno/conv2d_transpose_nchw.py +++ b/python/tvm/topi/adreno/conv2d_transpose_nchw.py @@ -20,15 +20,13 @@ from tvm import te from tvm import autotvm from .. 
import nn -from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity + from ..utils import get_const_tuple, traverse_inline from .utils import ( split_to_chunks, pack_input, pack_filter, - expand_spatial_dimensions, - add_pad, bind_data_copy, get_default_conv2d_config, get_texture_storage, From c5a2d58405b52f10a46c5452893f2504a0c2f67a Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Fri, 29 Sep 2023 11:39:59 +0530 Subject: [PATCH 06/18] Add more conv2dTranspose testcases --- .../test_conv2d_transpose_nchw_texture.py | 113 ++++++++++++------ tests/scripts/task_config_build_adreno.sh | 1 + 2 files changed, 80 insertions(+), 34 deletions(-) diff --git a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py index 6f2e15942fc0..ed859b1016ea 100644 --- a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py +++ b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py @@ -31,43 +31,88 @@ @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): - input_shape = (1, 256, 100, 100) - filter_shape = (256, 64, 4, 4) - channels = 64 - kernel_size = (4, 4) - strides = (2, 2) - padding = (1, 1, 1, 1) - x = relay.var("data", shape=input_shape, dtype=dtype) - w = relay.var("weight", shape=filter_shape, dtype=dtype) - y = relay.nn.conv2d_transpose( - x, - w, - channels=channels, - kernel_size=kernel_size, - strides=strides, - padding=padding, - kernel_layout="IOHW", - data_layout="NCHW", - ) + trials = [ + [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False)], + [4, 4, (0, 0), (2, 2), (1, 1), 256, (32, 64, 64), (False, False)], + [3, 3, (0, 0), (2, 2), (1, 1), 64, (256, 12, 12), (True, True)], + [4, 4, (1, 1), (1, 1), (1, 1), 512, (16, 100, 100), (False, False)], + [5, 5, (2, 2), (2, 2), (1, 1), 4, (16, 100, 100), (True, False)], + [7, 
7, (3, 3), (2, 2), (1, 1), 8, (4, 100, 100), (False, True)], + ] - mod = relay.Function([x, w], y) - np.random.seed(0) - initializer = relay.testing.init.Xavier() - filter_data = np.zeros(filter_shape).astype(dtype) - initializer("weight", filter_data) - params1 = { - "weight": tvm.nd.array(filter_data), - } - - if executor_type == "ge": - build_run_compare( - remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess - ) - else: - build_run_compare_vm( - remote, mod, params1, {"data": input_shape}, {"data": dtype}, target, [], gpu_preprocess + for ( + kernel_h, + kernel_w, + pad, + stride, + dilation, + out_channels, + shape, + composite, + ) in trials: + shape = (1, *shape) + has_bias = composite[0] + has_activation = composite[1] + input_shape = shape + filter_shape = (shape[1], out_channels, kernel_w, kernel_h) + x = relay.var("data", shape=input_shape, dtype=dtype) + w = relay.var("weight", shape=filter_shape, dtype=dtype) + inputs = [x, w] + y = relay.nn.conv2d_transpose( + x, + w, + channels=out_channels, + kernel_size=(kernel_w, kernel_h), + strides=stride, + padding=pad, + kernel_layout="IOHW", + data_layout="NCHW", + dilation=dilation, ) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data = np.zeros(filter_shape).astype(dtype) + initializer("weight", filter_data) + params1 = { + "weight": tvm.nd.array(filter_data), + } + + if has_bias: + b = relay.var("bias", shape=(out_channels,), dtype=dtype) + y = relay.nn.bias_add(y, b, axis=1) + inputs.append(b) + bias_data = np.zeros((out_channels,)).astype(dtype) + initializer("bias", bias_data) + params1["bias"] = tvm.nd.array(bias_data) + + if has_activation: + y = relay.nn.relu(y) + + mod = relay.Function(inputs, y) + + if executor_type == "ge": + build_run_compare( + remote, + mod, + params1, + {"data": input_shape}, + {"data": dtype}, + target, + [], + gpu_preprocess, + ) + else: + build_run_compare_vm( + remote, + mod, + params1, + {"data": input_shape}, 
+ {"data": dtype}, + target, + [], + gpu_preprocess, + ) if __name__ == "__main__": diff --git a/tests/scripts/task_config_build_adreno.sh b/tests/scripts/task_config_build_adreno.sh index 62e6ffecbced..25b856f52421 100755 --- a/tests/scripts/task_config_build_adreno.sh +++ b/tests/scripts/task_config_build_adreno.sh @@ -30,3 +30,4 @@ echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake echo set\(USE_LLVM ON\) >> config.cmake +echo set\(USE_OPENCL ON\) >> config.cmake From 61b7f6eb1a5b0cccb3f1530a96174fded7fcdee2 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 9 Oct 2023 18:37:52 +0530 Subject: [PATCH 07/18] empty update empty update for retrigger ci From cc49fc473e28bf2549dcc6543eb5b352094b0391 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 9 Oct 2023 18:40:17 +0530 Subject: [PATCH 08/18] Update test_conv2d_transpose_nchw_texture.py --- .../relay/opencl_texture/test_conv2d_transpose_nchw_texture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py index ed859b1016ea..688f457414a5 100644 --- a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py +++ b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py @@ -31,7 +31,7 @@ @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): - + # Conv2d transpose test cases lists trials = [ [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False)], [4, 4, (0, 0), (2, 2), (1, 1), 256, (32, 64, 64), (False, False)], From 7ccfd8ca68c82f29d01be592b0fdff207989829a Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 16 Oct 2023 15:35:49 +0530 Subject: [PATCH 09/18] Added more testcase to check memory scopes --- 
.../test_conv2d_transpose_nchw_texture.py | 106 +++++++++++++++++- 1 file changed, 102 insertions(+), 4 deletions(-) diff --git a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py index 688f457414a5..ce79b820ac21 100644 --- a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py +++ b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py @@ -23,8 +23,9 @@ from tvm.contrib import utils from utils.adreno_utils import gpu_preprocess, build_run_compare, build_run_compare_vm import pytest +from conftest import remote -executor_type = tvm.testing.parameter("ge") +executor_type = tvm.testing.parameter("ge", "vm") dtype = tvm.testing.parameter("float32") @@ -59,8 +60,12 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): x = relay.var("data", shape=input_shape, dtype=dtype) w = relay.var("weight", shape=filter_shape, dtype=dtype) inputs = [x, w] + W1 = relay.var("weight1", shape=(shape[1], shape[1], 1, 1), dtype=dtype) + conv = relay.nn.conv2d(x, W1, padding=[0, 0, 0, 0], channels=shape[1], kernel_size=(1, 1)) + inputs.append(W1) + conv = relay.op.nn.relu(conv) y = relay.nn.conv2d_transpose( - x, + conv, w, channels=out_channels, kernel_size=(kernel_w, kernel_h), @@ -70,6 +75,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): data_layout="NCHW", dilation=dilation, ) + np.random.seed(0) initializer = relay.testing.init.Xavier() filter_data = np.zeros(filter_shape).astype(dtype) @@ -78,6 +84,82 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): "weight": tvm.nd.array(filter_data), } + if has_bias: + b = relay.var("bias", shape=(out_channels,), dtype=dtype) + y = relay.nn.bias_add(y, b, axis=1) + inputs.append(b) + bias_data = np.zeros((out_channels,)).astype(dtype) + initializer("bias", bias_data) + params1["bias"] = tvm.nd.array(bias_data) + if has_activation: + y = 
relay.nn.relu(y) + + mod = relay.Function(inputs, out) + if executor_type == "ge": + build_run_compare( + remote, + mod, + params1, + {"data": input_shape}, + {"data": dtype}, + target, + [], + gpu_preprocess, + ) + else: + build_run_compare_vm( + remote, + mod, + params1, + {"data": input_shape}, + {"data": dtype}, + target, + [], + gpu_preprocess, + ) + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype): + # Conv2d transpose test cases lists + trials = [ + [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False)], + [3, 3, (0, 0), (2, 2), (1, 1), 64, (256, 12, 12), (True, True)], + ] + + for ( + kernel_h, + kernel_w, + pad, + stride, + dilation, + out_channels, + shape, + composite, + ) in trials: + shape = (1, *shape) + has_bias = composite[0] + has_activation = composite[1] + input_shape = shape + filter_shape = (shape[1], out_channels, kernel_w, kernel_h) + x = relay.var("data", shape=input_shape, dtype=dtype) + w = relay.var("weight", shape=filter_shape, dtype=dtype) + inputs = [x, w] + W1 = relay.var("weight1", shape=(shape[1], shape[1], 1, 1), dtype=dtype) + conv = relay.nn.conv2d(x, W1, padding=[0, 0, 0, 0], channels=shape[1], kernel_size=(1, 1)) + inputs.append(W1) + conv = relay.op.nn.relu(conv) + y = relay.nn.conv2d_transpose( + conv, + w, + channels=out_channels, + kernel_size=(kernel_w, kernel_h), + strides=stride, + padding=pad, + kernel_layout="IOHW", + data_layout="NCHW", + dilation=dilation, + ) if has_bias: b = relay.var("bias", shape=(out_channels,), dtype=dtype) y = relay.nn.bias_add(y, b, axis=1) @@ -88,8 +170,25 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): if has_activation: y = relay.nn.relu(y) + W2 = relay.var("weight2", shape=(out_channels, out_channels, 1, 1), dtype=dtype) + out = relay.nn.conv2d(y, W2, padding=[0, 0, 0, 0], channels=out_channels, kernel_size=(1, 1)) + out = 
relay.op.nn.relu(out) + np.random.seed(0) + inputs.append(W2) + initializer = relay.testing.init.Xavier() + filter_data = np.zeros(filter_shape).astype(dtype) + initializer("weight", filter_data) + filter_data1 = np.zeros((shape[1], shape[1], 1, 1)).astype(dtype) + initializer("weight", filter_data1) + filter_data2 = np.zeros((out_channels, out_channels, 1, 1)).astype(dtype) + initializer("weight", filter_data2) + params1 = { + "weight": tvm.nd.array(filter_data), + "weight1": tvm.nd.array(filter_data1), + "weight2": tvm.nd.array(filter_data2), + } - mod = relay.Function(inputs, y) + mod = relay.Function(inputs, out) if executor_type == "ge": build_run_compare( @@ -114,6 +213,5 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): gpu_preprocess, ) - if __name__ == "__main__": tvm.testing.main() From a438d2090b9cf1470f8eefa57914de13605b5830 Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 17 Oct 2023 14:02:02 +0530 Subject: [PATCH 10/18] Device specific alter_op_layout for conv2d_transpose --- python/tvm/relay/op/nn/_nn.py | 6 + python/tvm/topi/adreno/__init__.py | 1 + .../topi/adreno/conv2d_transpose_alter_op.py | 123 ++++++++++++++++++ python/tvm/topi/nn/conv2d.py | 24 ++++ 4 files changed, 154 insertions(+) create mode 100644 python/tvm/topi/adreno/conv2d_transpose_alter_op.py diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index c68685f0ae09..6acaf43fe7d2 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -335,6 +335,12 @@ def legalize_conv2d_transpose(attrs, inputs, types): return topi.nn.conv2d_transpose_legalize(attrs, inputs, types) +@reg.register_alter_op_layout("nn.conv2d_transpose") +def alter_op_layout_conv2d_transpose(attrs, inputs, tinfos, out_type): + """Alternate the layout of conv2d_transpose""" + return topi.nn.conv2d_transpose_alter_layout(attrs, inputs, tinfos, out_type) + + @reg.register_convert_op_layout("nn.conv2d_transpose") def convert_conv2d_transpose(attrs, 
inputs, tinfos, desired_layouts): """Convert Layout pass registration for conv2d_transpose op. diff --git a/python/tvm/topi/adreno/__init__.py b/python/tvm/topi/adreno/__init__.py index f54c6f9bd6fe..cd42848b29b3 100644 --- a/python/tvm/topi/adreno/__init__.py +++ b/python/tvm/topi/adreno/__init__.py @@ -23,6 +23,7 @@ from .depthwise_conv2d_nhwc import * from .pooling import * from .conv2d_alter_op import * +from .conv2d_transpose_alter_op import * from .conv2d_nchw_winograd import * from .conv2d_nhwc_winograd import * from .injective import schedule_injective diff --git a/python/tvm/topi/adreno/conv2d_transpose_alter_op.py b/python/tvm/topi/adreno/conv2d_transpose_alter_op.py new file mode 100644 index 000000000000..9c00962d60f0 --- /dev/null +++ b/python/tvm/topi/adreno/conv2d_transpose_alter_op.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +"""Conv2D Transpose alter op for Qualcomm Adreno GPU""" + +import logging + +import re +import tvm +from tvm import te +from tvm import relay +from tvm import autotvm +from ..utils import get_const_tuple +from .utils import infer_tile_size +from ..nn import conv2d_alter_layout +from ..nn import conv2d_transpose_alter_layout + +logger = logging.getLogger("topi") + +# Number of wildcards for matching of supported layouts to be transformed +_NCHWc_matcher = re.compile("^NCHW[0-9]+c$") +_IOHWo_matcher = re.compile("^IOHW[0-9]+o$") + + +@conv2d_transpose_alter_layout.register("adreno") +def _alter_conv2d_transpose_layout(attrs, inputs, tinfos, out_type): + """ + Prepare of the new conv2d_transpose with proper target blocked layout attributes + OpenCL Textures supports 1d/2d/3d/4d tetures but read happens always only for 4 elements + in a line. Thus way we are supporting for now only 4d conversions on the end + NCHW -> NCHW4c & IOHW ->IOHW4o + """ + target = tvm.target.Target.current(allow_none=False) + dispatch_ctx = autotvm.task.DispatchContext.current + new_attrs = {k: attrs[k] for k in attrs.keys()} + + # Parse the attributes. 
+ padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + data_layout = attrs["data_layout"] + kernel_layout = attrs["kernel_layout"] + data_tensor, kernel_tensor = tinfos + data_dtype = data_tensor.dtype + out_dtype = out_type.dtype + + if isinstance(dispatch_ctx, autotvm.task.ApplyGraphBest): + cfg = dispatch_ctx.query(target, None) + workload = cfg.workload + else: + impl, outs = relay.backend.te_compiler.select_implementation( + relay.op.get("nn.conv2d_transpose"), attrs, tinfos, out_type, target + ) + workload = autotvm.task.get_workload(outs) + cfg = dispatch_ctx.query(target, workload) + + topi_tmpl = workload[0] + + if "conv2d_transpose_nchwc" in topi_tmpl: # covers conv2d_transpose_nchwc + if data_layout == "NCHW" and kernel_layout == "IOHW": + batch, in_channels, in_height, in_width = data_tensor.shape + _, out_channles, kernel_h, kernel_w = kernel_tensor.shape + in_channel_block = in_channels % 4 + if in_channel_block == 0: + in_channel_block = 4 + num_filter_block = out_channles % 4 + if num_filter_block == 0: + num_filter_block = 4 + + # no support yet for tensors that cannot be divisible by factor 4 + if num_filter_block != 4: + return None + + batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) + in_filter_channel, out_channel, kh, kw = get_const_tuple(kernel_tensor.shape) + + # update new attrs + new_attrs["channels"] = out_channel + if in_channel_block == 4: + new_attrs["data_layout"] = f"NCHW{in_channel_block}c" + else: + new_attrs["data_layout"] = "NCHW" + # (oc, ic, h, w) -> (ic, OC, h, w, oc) + new_attrs["kernel_layout"] = f"IOHW{num_filter_block}o" + new_attrs["out_layout"] = f"NCHW{num_filter_block}c" + + # Store altered operator's config for applying of tuned AutoTVM statistics + if in_channel_block == 4: + new_data = te.placeholder( + (batch_size, in_channel // in_channel_block, height, width, in_channel_block), + dtype=data_dtype, + ) + else: 
+ new_data = data_tensor + new_kernel = te.placeholder( + (in_filter_channel, out_channel // num_filter_block, kh, kw, num_filter_block), + dtype=kernel_tensor.dtype, + ) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, out_dtype], + topi_tmpl, # "conv2d_transpose_nchwc.image2d", + ) + dispatch_ctx.update(target, new_workload, cfg) + else: + assert _NCHWc_matcher.match(data_layout) + assert _IOHWo_matcher.match(kernel_layout) + return relay.nn.conv2d_transpose(*inputs, **new_attrs) + + return None diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index f70d749e0f3c..792bdb7bb7b3 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -143,6 +143,30 @@ def conv2d_alter_layout(attrs, inputs, tinfos, out_type): return None +@tvm.target.generic_func +def conv2d_transpose_alter_layout(attrs, inputs, tinfos, out_type): + """Change Conv2D_Transpose layout. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : tvm.relay.Expr + Grouped input symbols + tinfos : list + Input shape and dtype + out_type: type + The output type + + Note + ---- + Unlike other TOPI functions, this function operates on both graph level and operator level. + """ + print("Transpose conv alter op layout called") + # not to change by default + return None + + @tvm.target.generic_func def conv2d_infer_layout(workload, cfg): """Infer input/output shapes and layouts from a workload and cfg. 
From 6c32b1bee809c0454a6143ae9b31f724df6bdf32 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Thu, 26 Oct 2023 17:53:38 +0530 Subject: [PATCH 11/18] Fix in virtual device setup and added test case with scope check --- .../topi/adreno/conv2d_transpose_alter_op.py | 2 - python/tvm/topi/nn/conv2d.py | 1 - .../transforms/annotate_texture_storage.cc | 4 + .../test_conv2d_transpose_nchw_texture.py | 142 +++++++++++++++--- 4 files changed, 124 insertions(+), 25 deletions(-) diff --git a/python/tvm/topi/adreno/conv2d_transpose_alter_op.py b/python/tvm/topi/adreno/conv2d_transpose_alter_op.py index 9c00962d60f0..c68e5cb7a558 100644 --- a/python/tvm/topi/adreno/conv2d_transpose_alter_op.py +++ b/python/tvm/topi/adreno/conv2d_transpose_alter_op.py @@ -25,8 +25,6 @@ from tvm import relay from tvm import autotvm from ..utils import get_const_tuple -from .utils import infer_tile_size -from ..nn import conv2d_alter_layout from ..nn import conv2d_transpose_alter_layout logger = logging.getLogger("topi") diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 792bdb7bb7b3..1b23fa542dd1 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -162,7 +162,6 @@ def conv2d_transpose_alter_layout(attrs, inputs, tinfos, out_type): ---- Unlike other TOPI functions, this function operates on both graph level and operator level. 
""" - print("Transpose conv alter op layout called") # not to change by default return None diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc index 01d47b69530b..9ccb2171d8e9 100644 --- a/src/relay/transforms/annotate_texture_storage.cc +++ b/src/relay/transforms/annotate_texture_storage.cc @@ -392,6 +392,10 @@ class StorageInfo : private transform::DeviceAwareExprVisitor { (attrs->kernel_layout == "OIHW4o" || attrs->kernel_layout == "HWIO4o")) { supports_texture_storage = true; } + } else if (auto attrs = call->attrs.as()) { + if (attrs->data_layout == "NCHW4c" && attrs->kernel_layout == "IOHW4o") { + supports_texture_storage = true; + } } else if (auto attrs = call->attrs.as()) { if (attrs->layout == "NCHW4c") { supports_texture_storage = true; diff --git a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py index ce79b820ac21..f1986ca77a00 100644 --- a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py +++ b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py @@ -23,10 +23,10 @@ from tvm.contrib import utils from utils.adreno_utils import gpu_preprocess, build_run_compare, build_run_compare_vm import pytest -from conftest import remote + executor_type = tvm.testing.parameter("ge", "vm") -dtype = tvm.testing.parameter("float32") +dtype = tvm.testing.parameter("float32", "float16") @tvm.testing.requires_opencl @@ -36,13 +36,61 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): trials = [ [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False)], [4, 4, (0, 0), (2, 2), (1, 1), 256, (32, 64, 64), (False, False)], - [3, 3, (0, 0), (2, 2), (1, 1), 64, (256, 12, 12), (True, True)], + [3, 3, (0, 0), (2, 2), (1, 1), 64, (256, 100, 100), (True, True)], [4, 4, (1, 1), (1, 1), (1, 1), 512, (16, 100, 100), (False, False)], [5, 5, (2, 2), (2, 2), 
(1, 1), 4, (16, 100, 100), (True, False)], [7, 7, (3, 3), (2, 2), (1, 1), 8, (4, 100, 100), (False, True)], ] + ge_texture_scopes = [ + ["", "global.texture", "global.texture-weight", "", ""], + ["", "global.texture", "global.texture-weight", "", ""], + ["", "global.texture", "global.texture-weight", "global.texture-weight", "", ""], + ["", "global.texture", "global.texture-weight", "", ""], + ["", "global.texture", "global.texture-weight", "global.texture-weight", "", ""], + ["", "global.texture", "global.texture-nhwc", "", ""], + ] + vm_texture_scopes = [ + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight + """, + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight + """, + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight + VM VirtualDevice[4]: device type 4, id 0 and mem_scope global.texture-weight + """, + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight + """, + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and 
mem_scope global.texture-weight + VM VirtualDevice[4]: device type 4, id 0 and mem_scope global.texture-weight + """, + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-nhwc + """, + ] - for ( + for i, ( kernel_h, kernel_w, pad, @@ -51,7 +99,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): out_channels, shape, composite, - ) in trials: + ) in enumerate(trials): shape = (1, *shape) has_bias = composite[0] has_activation = composite[1] @@ -60,12 +108,8 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): x = relay.var("data", shape=input_shape, dtype=dtype) w = relay.var("weight", shape=filter_shape, dtype=dtype) inputs = [x, w] - W1 = relay.var("weight1", shape=(shape[1], shape[1], 1, 1), dtype=dtype) - conv = relay.nn.conv2d(x, W1, padding=[0, 0, 0, 0], channels=shape[1], kernel_size=(1, 1)) - inputs.append(W1) - conv = relay.op.nn.relu(conv) y = relay.nn.conv2d_transpose( - conv, + x, w, channels=out_channels, kernel_size=(kernel_w, kernel_h), @@ -94,7 +138,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): if has_activation: y = relay.nn.relu(y) - mod = relay.Function(inputs, out) + mod = relay.Function(inputs, y) if executor_type == "ge": build_run_compare( remote, @@ -103,7 +147,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): {"data": input_shape}, {"data": dtype}, target, - [], + ge_texture_scopes[i], gpu_preprocess, ) else: @@ -114,20 +158,69 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): {"data": input_shape}, {"data": dtype}, target, - [], + vm_texture_scopes[i], gpu_preprocess, ) + @tvm.testing.requires_opencl @tvm.testing.parametrize_targets("opencl -device=adreno") def 
test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype): # Conv2d transpose test cases lists trials = [ [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False)], - [3, 3, (0, 0), (2, 2), (1, 1), 64, (256, 12, 12), (True, True)], + [3, 3, (0, 0), (1, 1), (1, 1), 64, (256, 12, 12), (True, True)], + ] + ge_texture_scopes = [ + [ + "", + "global.texture", + "global.texture-weight", + "global.texture", + "global.texture-weight", + "global.texture", + "global.texture-weight", + "", + "", + ], + [ + "", + "global.texture-nhwc", + "global.texture-weight", + "global.texture-nhwc", + "global.texture-weight", + "global.texture-weight", + "global.texture-nhwc", + "global.texture-weight", + "", + "", + ], + ] + vm_texture_scopes = [ + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[4]: device type 4, id 0 and mem_scope global.texture-weight + VM VirtualDevice[5]: device type 4, id 0 and mem_scope global.texture + VM VirtualDevice[6]: device type 4, id 0 and mem_scope global.texture-weight + VM VirtualDevice[7]: device type 4, id 0 and mem_scope global.texture-weight + """, + """ + VM VirtualDevice[0]: device type 1, id 0 and mem_scope + VM VirtualDevice[1]: device type 4, id 0 and mem_scope + VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture-nhwc + VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-nhwc + VM VirtualDevice[4]: device type 4, id 0 and mem_scope global.texture-weight + VM VirtualDevice[5]: device type 4, id 0 and mem_scope global.texture-nhwc + VM VirtualDevice[6]: device type 4, id 0 and mem_scope global.texture-weight + VM VirtualDevice[7]: device type 4, id 0 and mem_scope global.texture-weight + VM VirtualDevice[8]: device type 4, id 0 and mem_scope 
global.texture-weight + """, ] - for ( + for i, ( kernel_h, kernel_w, pad, @@ -136,7 +229,7 @@ def test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype out_channels, shape, composite, - ) in trials: + ) in enumerate(trials): shape = (1, *shape) has_bias = composite[0] has_activation = composite[1] @@ -160,18 +253,18 @@ def test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype data_layout="NCHW", dilation=dilation, ) + if has_bias: b = relay.var("bias", shape=(out_channels,), dtype=dtype) y = relay.nn.bias_add(y, b, axis=1) inputs.append(b) - bias_data = np.zeros((out_channels,)).astype(dtype) - initializer("bias", bias_data) - params1["bias"] = tvm.nd.array(bias_data) if has_activation: y = relay.nn.relu(y) W2 = relay.var("weight2", shape=(out_channels, out_channels, 1, 1), dtype=dtype) - out = relay.nn.conv2d(y, W2, padding=[0, 0, 0, 0], channels=out_channels, kernel_size=(1, 1)) + out = relay.nn.conv2d( + y, W2, padding=[0, 0, 0, 0], channels=out_channels, kernel_size=(1, 1) + ) out = relay.op.nn.relu(out) np.random.seed(0) inputs.append(W2) @@ -187,6 +280,10 @@ def test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype "weight1": tvm.nd.array(filter_data1), "weight2": tvm.nd.array(filter_data2), } + if has_bias: + bias_data = np.zeros((out_channels,)).astype(dtype) + initializer("bias", bias_data) + params1["bias"] = tvm.nd.array(bias_data) mod = relay.Function(inputs, out) @@ -198,7 +295,7 @@ def test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype {"data": input_shape}, {"data": dtype}, target, - [], + ge_texture_scopes[i], gpu_preprocess, ) else: @@ -209,9 +306,10 @@ def test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype {"data": input_shape}, {"data": dtype}, target, - [], + vm_texture_scopes[i], gpu_preprocess, ) + if __name__ == "__main__": tvm.testing.main() From b2dc7b036e60a533c50c8e72306070a5b3a2ba5d Mon Sep 17 00:00:00 2001 From: 
krishnaraj36 Date: Mon, 30 Oct 2023 10:17:37 +0530 Subject: [PATCH 12/18] Add the comment conv2d algo --- python/tvm/topi/adreno/conv2d_transpose_nchw.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/tvm/topi/adreno/conv2d_transpose_nchw.py b/python/tvm/topi/adreno/conv2d_transpose_nchw.py index 75abca761489..301e8fc5b5f5 100644 --- a/python/tvm/topi/adreno/conv2d_transpose_nchw.py +++ b/python/tvm/topi/adreno/conv2d_transpose_nchw.py @@ -40,6 +40,18 @@ def conv2d_transpose_nchwc( """ Transposed Convolution operator in NCHWc layout. Algo: + 1. Convert into blocked format if we have 4d original tensor. + In case of AutoTVM we override the convert by just tensors since such conversion + will be absent for real blocked convolution, no sense to include into tuning + 2. Expand spatial dimensions to have width and height be dividable by factor 4 + This leads to slightly bigger amount of compute but allow utilize GPU much better + 3. Add paddings. This happens even if we do not need pad originaly. This is useful + due to work arounding of the gaps of texture annotation between Primary Functions + and limited support of textures in schedules. Later on this pad will be executed + separately and will produce texture + 4. 5d Convolution compute with accumulating into out_dtype + 5. Cast to the origin output data type + 6. 
For case of 4d convolution: convert of output from 5d to 4d """ if out_dtype is None: From 4dd6efdf161b4b4f539386450cf4e321884c1b56 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 30 Oct 2023 10:36:27 +0530 Subject: [PATCH 13/18] Add the comment conv2d algo --- python/tvm/topi/adreno/conv2d_transpose_nchw.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/tvm/topi/adreno/conv2d_transpose_nchw.py b/python/tvm/topi/adreno/conv2d_transpose_nchw.py index 301e8fc5b5f5..ad8c7b88ef50 100644 --- a/python/tvm/topi/adreno/conv2d_transpose_nchw.py +++ b/python/tvm/topi/adreno/conv2d_transpose_nchw.py @@ -231,6 +231,23 @@ def schedule_conv2d_transpose_NCHWc(cfg, s, output): schedule optimized for batch size = 1 Algo: + 1. Split output axis to three parts: global work size, vthread, local worksize. + The limitations for tuning includes heuristics from some tuned networks to limit + search space and not pay much time for useles configurations. + 2. In case of 4d convolution schedule copying of the input (and filter) into + 5d tensors + 4. pad should be scheduled separately to create independent opencl kernel. If pad is + inlined into convolution, this gives 1.5x performance drop + 5. We are using cache_read for intermediate tensors to produce texture and guarantee + the best performance on the next stage. + The weights are managed through static texture planning mechanism and guarantied come + in texture memory scope. + Thus way we are calling cache_read only for data tensor + 6. For 5d convolution we schedule the latest op with binding 5d axis and vectorize + for textures + For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion + of data type + 7. 
In case of 4d conv we need to schedule postops as well """ latest = s.outputs[0].output(0) if len(latest.op.axis) == 4: From adc75b802221c9baa892707cea76479246828d2b Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 30 Oct 2023 16:27:51 +0530 Subject: [PATCH 14/18] Removed fp16 test case from texture It is failing for a few GPU devices. --- .../relay/opencl_texture/test_conv2d_transpose_nchw_texture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py index f1986ca77a00..241076e92302 100644 --- a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py +++ b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py @@ -26,7 +26,7 @@ executor_type = tvm.testing.parameter("ge", "vm") -dtype = tvm.testing.parameter("float32", "float16") +dtype = tvm.testing.parameter("float32") @tvm.testing.requires_opencl From da87c615344bac0929ddf19c2d055aa0e6567a2d Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Tue, 31 Oct 2023 17:46:43 +0530 Subject: [PATCH 15/18] remove opencl config change for mainline conflict --- tests/scripts/task_config_build_adreno.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/scripts/task_config_build_adreno.sh b/tests/scripts/task_config_build_adreno.sh index 25b856f52421..62e6ffecbced 100755 --- a/tests/scripts/task_config_build_adreno.sh +++ b/tests/scripts/task_config_build_adreno.sh @@ -30,4 +30,3 @@ echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake echo set\(USE_LLVM ON\) >> config.cmake -echo set\(USE_OPENCL ON\) >> config.cmake From c1b4d5ff5f534dcb967bef424aab33389913e412 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 6 Nov 2023 14:28:12 +0530 Subject: [PATCH 16/18] Add the test case for 3 channel input which runs with cuda schedule --- 
.../test_conv2d_transpose_nchw_texture.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py index 241076e92302..735fd474cfcd 100644 --- a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py +++ b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py @@ -34,12 +34,13 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): # Conv2d transpose test cases lists trials = [ - [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False)], - [4, 4, (0, 0), (2, 2), (1, 1), 256, (32, 64, 64), (False, False)], - [3, 3, (0, 0), (2, 2), (1, 1), 64, (256, 100, 100), (True, True)], - [4, 4, (1, 1), (1, 1), (1, 1), 512, (16, 100, 100), (False, False)], - [5, 5, (2, 2), (2, 2), (1, 1), 4, (16, 100, 100), (True, False)], - [7, 7, (3, 3), (2, 2), (1, 1), 8, (4, 100, 100), (False, True)], + [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False), gpu_preprocess], + [4, 4, (0, 0), (2, 2), (1, 1), 256, (32, 64, 64), (False, False), None], + [3, 3, (0, 0), (2, 2), (1, 1), 64, (256, 100, 100), (True, True), None], + [4, 4, (1, 1), (1, 1), (1, 1), 512, (16, 100, 100), (False, False), gpu_preprocess], + [5, 5, (2, 2), (2, 2), (1, 1), 4, (16, 100, 100), (True, False), gpu_preprocess], + [7, 7, (3, 3), (2, 2), (1, 1), 8, (4, 100, 100), (False, True), None], + [7, 7, (3, 3), (2, 2), (1, 1), 64, (3, 100, 100), (True, True), None], ] ge_texture_scopes = [ ["", "global.texture", "global.texture-weight", "", ""], @@ -48,6 +49,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): ["", "global.texture", "global.texture-weight", "", ""], ["", "global.texture", "global.texture-weight", "global.texture-weight", "", ""], ["", "global.texture", "global.texture-nhwc", "", ""], + [], ] vm_texture_scopes = [ """ @@ -88,6 +90,7 @@ 
def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-nhwc """, + [], ] for i, ( @@ -99,6 +102,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): out_channels, shape, composite, + _gpu_preprocess, ) in enumerate(trials): shape = (1, *shape) has_bias = composite[0] @@ -148,7 +152,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): {"data": dtype}, target, ge_texture_scopes[i], - gpu_preprocess, + _gpu_preprocess, ) else: build_run_compare_vm( @@ -159,7 +163,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): {"data": dtype}, target, vm_texture_scopes[i], - gpu_preprocess, + _gpu_preprocess, ) @@ -168,8 +172,8 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): def test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype): # Conv2d transpose test cases lists trials = [ - [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False)], - [3, 3, (0, 0), (1, 1), (1, 1), 64, (256, 12, 12), (True, True)], + [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False), None], + [3, 3, (0, 0), (1, 1), (1, 1), 64, (256, 12, 12), (True, True), gpu_preprocess], ] ge_texture_scopes = [ [ @@ -229,6 +233,7 @@ def test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype out_channels, shape, composite, + _gpu_preprocess, ) in enumerate(trials): shape = (1, *shape) has_bias = composite[0] @@ -296,7 +301,7 @@ def test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype {"data": dtype}, target, ge_texture_scopes[i], - gpu_preprocess, + _gpu_preprocess, ) else: build_run_compare_vm( @@ -307,7 +312,7 @@ def test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype {"data": dtype}, target, vm_texture_scopes[i], - gpu_preprocess, + _gpu_preprocess, 
) From b95aacd0deaf2177ace81ee6afef1ab0883d4a28 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 6 Nov 2023 14:49:30 +0530 Subject: [PATCH 17/18] Fix in op strategy for out channel 3 --- python/tvm/relay/op/strategy/adreno.py | 4 ++-- .../opencl_texture/test_conv2d_transpose_nchw_texture.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index 205355b40f1d..bacace9ad4f6 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -232,9 +232,9 @@ def conv2d_transpose_strategy_adreno(attrs, inputs, out_type, target): or (data_layout == "NCHW" and kernel_layout == "IOHW4o") ): if len(kernel.shape) == 4: - oc, _, _, _ = get_const_tuple(kernel.shape) + _, oc, _, _ = get_const_tuple(kernel.shape) else: - oc, _, _, _, _ = get_const_tuple(kernel.shape) + _, oc, _, _, _ = get_const_tuple(kernel.shape) # We cannot use textures for case than number of channels is less than 4. # So, we use compute functions from cuda. 
if len(kernel.shape) == 4 and oc < 4: diff --git a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py index 735fd474cfcd..612e976b0c30 100644 --- a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py +++ b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py @@ -41,6 +41,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): [5, 5, (2, 2), (2, 2), (1, 1), 4, (16, 100, 100), (True, False), gpu_preprocess], [7, 7, (3, 3), (2, 2), (1, 1), 8, (4, 100, 100), (False, True), None], [7, 7, (3, 3), (2, 2), (1, 1), 64, (3, 100, 100), (True, True), None], + [3, 3, (1, 1), (1, 1), (1, 1), 3, (16, 8, 8), (True, True), None], ] ge_texture_scopes = [ ["", "global.texture", "global.texture-weight", "", ""], @@ -50,6 +51,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): ["", "global.texture", "global.texture-weight", "global.texture-weight", "", ""], ["", "global.texture", "global.texture-nhwc", "", ""], [], + [], ] vm_texture_scopes = [ """ @@ -91,6 +93,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-nhwc """, [], + [], ] for i, ( From d64268b5d7bbe11a55d461e6459128a456b72976 Mon Sep 17 00:00:00 2001 From: krishnaraj36 Date: Mon, 6 Nov 2023 22:50:39 +0530 Subject: [PATCH 18/18] Comment in test case for memory scope --- .../relay/opencl_texture/test_conv2d_transpose_nchw_texture.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py index 612e976b0c30..d110c8329fd1 100644 --- a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py +++ b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py @@ -43,6 +43,7 @@ def 
test_conv2d_transpose_adreno(remote, target, executor_type, dtype): [7, 7, (3, 3), (2, 2), (1, 1), 64, (3, 100, 100), (True, True), None], [3, 3, (1, 1), (1, 1), (1, 1), 3, (16, 8, 8), (True, True), None], ] + # Tensors memory scope with graph executor build ge_texture_scopes = [ ["", "global.texture", "global.texture-weight", "", ""], ["", "global.texture", "global.texture-weight", "", ""], @@ -53,6 +54,7 @@ def test_conv2d_transpose_adreno(remote, target, executor_type, dtype): [], [], ] + # Tensors memory scope with vm executor build vm_texture_scopes = [ """ VM VirtualDevice[0]: device type 1, id 0 and mem_scope