From 2740929073eeabdbc678aad11bbe2f43f80b98a9 Mon Sep 17 00:00:00 2001
From: Egor Churaev
Date: Thu, 9 Feb 2023 20:05:07 +0300
Subject: [PATCH] [Adreno] Extend pack_filter for HWIO layout

---
 python/tvm/topi/adreno/utils.py               | 19 ++++++++
 .../test_conv2d_nhwc_texture.py               | 47 +++++++++++++++++++
 2 files changed, 66 insertions(+)

diff --git a/python/tvm/topi/adreno/utils.py b/python/tvm/topi/adreno/utils.py
index 1a1cc747faac..9716a62fcc7e 100644
--- a/python/tvm/topi/adreno/utils.py
+++ b/python/tvm/topi/adreno/utils.py
@@ -237,6 +237,18 @@ def _reorder_weights_depthwise_hwoi(*indices):
             Filter[indices[0], indices[1], indices[2] * out_block + indices[4], indices[3]],
         )
 
+    def _reorder_weights_depthwise_hwio(*indices):
+        """Pack a depthwise HWIO filter, padding the tail of the last output chunk."""
+        chunk, lane = indices[3], indices[4]
+        # Only lanes at or beyond out_original_tail in the final chunk are padding.
+        is_padding = tvm.tir.all(chunk == out_chunks - 1, lane >= out_original_tail)
+
+        return tvm.tir.if_then_else(
+            is_padding,
+            pad_value,
+            Filter[indices[0], indices[1], indices[2], chunk * out_block + lane],
+        )
+
     def _reorder_weights_oihw(*indices):
         conditionA = []
         conditionA.append(indices[0] == out_chunks - 1)
@@ -284,6 +296,13 @@ def _reorder_weights_hwio(*indices):
                 name="filter_pack",
                 tag="filter_pack",
             )
+        elif layout == "HWIO":
+            reordered_filter = te.compute(
+                [kernel_h, kernel_w, in_filter_channels, out_chunks, out_block],
+                _reorder_weights_depthwise_hwio,
+                name="filter_pack",
+                tag="filter_pack",
+            )
         else:
             assert False, "Adreno util function def pack_filter does not accept unknown layout"
     else:
diff --git a/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py b/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py
index 43979cc79a68..f2bfc91174b3 100644
--- a/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py
+++ b/tests/python/relay/opencl_texture/test_conv2d_nhwc_texture.py
@@ -589,6 +589,53 @@ def test_conv2d_vgg16_winograd_4d(remote, target, dtype):
     assert len(matches) > 0
 
 
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_conv2d_vgg16_winograd_4d_expand_spatial_dims(remote, target, dtype):
+    """Run an NHWC/HWIO 3x3 conv2d + bias + relu and expect "winograd" in the build output."""
+    input_shape = (1, 28, 28, 1)
+    filter_shape = (3, 3, 1, 64)
+    bias_shape = (1, 1, 1, 64)
+    data = relay.var("data", shape=input_shape, dtype=dtype)
+    weight = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        data,
+        weight,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        padding=[0, 0, 0, 0],
+        kernel_size=[3, 3],
+        out_dtype=dtype,
+    )
+    activated = relay.op.nn.relu(relay.op.add(conv, bias))
+
+    mod = relay.Function([data, weight, bias], activated)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    weight_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", weight_data)
+    initializer("bias", bias_data)
+    params = {
+        "weight": tvm.nd.array(weight_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    # Tuning record written to the stat file that is handed to build_run_compare.
+    record = f'{{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256", "conv2d_nhwc_winograd.image2d", [["TENSOR", [1, 28, 28, 1], "{dtype}"], ["TENSOR", [3, 3, 1, 64], "{dtype}"], [1, 1], [0, 0, 0, 0], [1, 1], "{dtype}"], {{}}], "config": {{"index": 1591, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629.7427933], "version": 0.2, "tvm_version": "0.8.dev0"}}\n'
+    stat_file = utils.tempdir().relpath("stat.log")
+    with open(stat_file, "w") as f:
+        f.write(record)
+
+    graph = build_run_compare(
+        remote, mod, params, {"data": input_shape}, {"data": dtype}, target, stat_file=stat_file
+    )
+    assert len(re.findall("winograd", graph)) > 0
+
+
 @tvm.testing.requires_opencl
 @tvm.testing.parametrize_targets("opencl -device=adreno")
 def test_conv2d_winograd_conv(remote, target, dtype):