diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 613b1d084701..525ee95f4117 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -382,33 +382,6 @@ std::string CodeGenOpenCL::CastTo(std::string value, DataType target) { return os.str(); } -void CodeGenOpenCL::VisitStmt_(const BufferStoreNode* op) { - if (auto call = op->value.as()) { - if (call->op.same_as(builtin::texture2d_load())) { - need_texture_ssa_ = false; - // If storing a texture load into a buffer, don't use an - // intermediate local unless the buffer allocation is a - // single element selected from the texture read. - auto it = allocation_size_.find(op->buffer->data.get()); - if (it != allocation_size_.end() && it->second == 1) { - need_texture_ssa_ = true; - } - } - } - CodeGenC::VisitStmt_(op); - need_texture_ssa_ = true; -} - -void CodeGenOpenCL::VisitExpr_(const CastNode* op, std::ostream& os) { - if (auto call = op->value.as()) { - if (call->op.same_as(builtin::texture2d_load())) { - need_texture_ssa_ = false; - } - } - CodeGenC::VisitExpr_(op, os); - need_texture_ssa_ = true; -} - void CodeGenOpenCL::VisitStmt_(const AllocateNode* op) { allocation_size_.insert({op->buffer_var.get(), op->ConstantAllocationSize() * op->dtype.lanes()}); CodeGenC::VisitStmt_(op); @@ -472,20 +445,15 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[2], ss); ss << ")))"; - // Only use local SSA if texture is not already being stored - if (need_texture_ssa_) { - std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(4)); - if (op->args.back().as()) { - os << rhs; - } else { - os << "(("; - this->PrintType(op->dtype.with_lanes(1), os); - os << "*)&" << rhs << ")["; - this->PrintExpr(op->args.back(), os); - os << "]"; - } + std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(4)); + if (op->args.back().as()) { + os << rhs; } else { - os << ss.str(); + os << "(("; + this->PrintType(op->dtype.with_lanes(1), os); + os << "*)&" << rhs << ")["; + this->PrintExpr(op->args.back(), os); + os << "]"; } } else if (op->op.same_as(builtin_call_extern_)) { auto func = Downcast(op->args[0]); diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 05734b6a54eb..8b365f85d6e6 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -66,9 +66,7 @@ class CodeGenOpenCL final : public CodeGenC { void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) - void VisitStmt_(const BufferStoreNode* op) final; // NOLINT(*) // overload min and max to avoid ambiguous call errors void VisitExpr_(const MinNode* op, std::ostream& os) final; @@ -86,9 +84,6 @@ class CodeGenOpenCL final : public CodeGenC { // Whether to enable sampler or sampler-less texture reads, // where the choice depends on the OpenCL version used. bool enable_compliant_texture_reads_{false}; - // Key to disable use of texture SSA in certain scenarios. For example, - // when loaded value is stored directly to a user declared l-value buffer - bool need_texture_ssa_{true}; // Mapping from buffer to allocation size. // Useful to track when a scalar store of a vectorized texture load is required. std::unordered_map allocation_size_; diff --git a/tests/python/relay/opencl_texture/test_injection_texture.py b/tests/python/relay/opencl_texture/test_injection_texture.py new file mode 100644 index 000000000000..991983706fcb --- /dev/null +++ b/tests/python/relay/opencl_texture/test_injection_texture.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +import pytest +import tvm +import numpy as np +from tvm import relay +from tvm.relay import testing +from tvm.contrib import utils +from utils.adreno_utils import gpu_preprocess, build_run_compare + + +dtype = tvm.testing.parameter("float32") + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_layout_transform_to_block_nchw4c(remote, target, dtype): + """Verification of the case NCHW->NCHW4c""" + input_shape = (1, 32, 720, 1280) + A = relay.var("data", shape=input_shape, dtype=dtype) + lt = relay.layout_transform(A, "NCHW", "NCHW4c") + mod = relay.Function([A], lt) + + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_layout_transform_to_block_nchw(remote, target, dtype): + """Verification of the case NCHW4c->NCHW""" + input_shape = (1, 36, 1, 1, 4) + A = relay.var("data", shape=input_shape, dtype=dtype) + lt = relay.layout_transform(A, "NCHW4c", "NCHW") + mod = relay.Function([A], lt) + + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_layout_transform_to_block_nhwc4c(remote, target, dtype): + """Verification of the case NHWC->NHWC4c""" + input_shape = (1, 1, 1, 144) + A = relay.var("data", shape=input_shape, dtype=dtype) + lt = relay.layout_transform(A, "NHWC", "NHWC4c") + mod = relay.Function([A], lt) + + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + + +@pytest.mark.skipif( + tvm.testing.utils.IS_IN_CI, reason="Skip because GPU in CI doesn't support FP16" +) +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_layout_transform_to_block_nhwc(remote, target, dtype): + """Verification of the case NHWC4c->NHWC""" + input_shape = (1, 80, 80, 36, 4) + A = relay.var("data", shape=input_shape, dtype=dtype) + mean = relay.mean(A, axis=[1, 2], keepdims=True) + cast = relay.cast(mean, "float16") + lt = relay.layout_transform(cast, "NHWC4c", "NHWC") + mod = relay.Function([A], lt) + + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + + +if __name__ == "__main__": + test_layout_transform_to_block_nhwc(None, "opencl -device=adreno", "float16") diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index c25b3c2c86ea..bc2d0a84fd9d 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -185,8 +185,4 @@ def check_type_casting(ctx, n, dtype): if __name__ == "__main__": - test_opencl_ternary_expression() - test_opencl_inf_nan() - test_opencl_max() - test_opencl_erf() - test_opencl_type_casting() + tvm.testing.main() diff --git a/tests/python/unittest/test_target_texture_codegen_opencl.py b/tests/python/unittest/test_target_texture_codegen_opencl.py index 06876258e5d1..639159c495f0 100644 --- a/tests/python/unittest/test_target_texture_codegen_opencl.py +++ b/tests/python/unittest/test_target_texture_codegen_opencl.py @@ -1397,5 +1397,380 @@ class TestDepthwiseConv2dNCHWcKCRSk(BaseConv2DValidator): test_func = tvm.testing.parameter(depthwise_conv2d_NCHWc_KCRSk_acc32) +def simple_texture_to_scalar_common( + target, input_info, output_info, find_patterns, dtype, cast_type +): + def _compute(): + p0 = te.placeholder(input_info[1], name="p0", dtype=dtype) + p0_comp = te.compute(input_info[1], lambda *i: p0(*i), name="p0_comp") + if len(output_info[1]) == 4 and len(input_info[1]) == 5: + out = te.compute( + output_info[1], + lambda n, c, h, w: p0_comp[n][c // 4][h][w][c % 4].astype(cast_type), + name="out", + ) + elif len(output_info[1]) == 5 and len(input_info[1]) == 5: + out = te.compute( + output_info[1], + lambda n, c, h, w, cb: p0_comp[n][c][h][w][cb].astype(cast_type), + name="out", + ) + else: + raise Exception("Impossible case") + dummy_out = te.compute(output_info[1], lambda *i: out(*i), name="dummy_out") + return p0, dummy_out + + def _schedule(dummy_out): + from tvm.topi.adreno.utils import bind_data_copy + + s = te.create_schedule(dummy_out.op) + out = s[dummy_out].op.input_tensors[0] + p0_comp = s[out].op.input_tensors[0] + s[p0_comp].set_scope(input_info[0]) + bind_data_copy(s[p0_comp]) + s[out].set_scope(output_info[0]) + bind_data_copy(s[out]) + bind_data_copy(s[dummy_out]) + return s + + p0, dummy_out = _compute() + s = _schedule(dummy_out) + + fun = tvm.build(s, [p0, dummy_out], target) + dev = tvm.device(target, 0) + opencl_source = fun.imported_modules[0].get_source() + start_idx = 0 + for pattern in find_patterns: + start_idx = opencl_source.find(pattern, start_idx) + assert start_idx > -1 + + input_np = np.random.uniform(size=[i for i in input_info[1]]).astype(dtype) + input_tvm = tvm.nd.array(input_np, dev) + c = tvm.nd.empty(output_info[1], dtype, dev) + # Doesn't run OpenCL code for FP16 because GPUs in CI don't support FP16 inference + if cast_type == "float32": + fun(input_tvm, c) + # For output len == 5 it makes no sense to check the accuracy + if cast_type == "float32" and len(output_info[1]) == 4: + np_result = input_np.transpose(0, 2, 3, 1, 4) # NCHW4c -> NHWC4c + np_result = np.squeeze(np_result, axis=3) + np_result = np_result.transpose(0, 3, 1, 2) # NHWC -> NCHW + np.testing.assert_allclose(c.asnumpy(), np_result, rtol=1e-2, atol=1e-2) + + +class TestSimpleTextureToScalarFP16: + # (input [scope, shape], output [scope, shape], [find_patterns]) + input_info, output_info, find_patterns = tvm.testing.parameters( + # 1. Texture (NCHW4c) -> Cast(FP16) -> Buffer (NCHW) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));", + "out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = ((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)]);", + ], + ), + # 2. Buffer (NCHW4c) -> Cast(FP16) -> Buffer (NCHW) + ( + ["", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = ((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]);" + ], + ), + # 3. Texture (NCHW4c) -> Cast(FP16) -> Texture (NCHW4c) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["global.texture", (1, 1, 40, 40, 4)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5))));", + "write_imageh(out, (int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5)), (convert_half4(v_)));", + ], + ), + ) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.parametrize_targets("opencl") + def test_simple_texture_to_scalar_fp16( + self, input_info, output_info, find_patterns, dtype, target + ): + simple_texture_to_scalar_common( + target, input_info, output_info, find_patterns, dtype, "float16" + ) + + +class TestSimpleTextureToScalarFP32: + # (input [scope, shape], output [scope, shape], [find_patterns]) + input_info, output_info, find_patterns = tvm.testing.parameters( + # 1. Texture (NCHW4c) -> Buffer (NCHW) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));", + "out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = ((float*)&v_)[(((int)get_group_id(0)) >> 1)];", + ], + ), + # 2. Buffer (NCHW4c) -> Buffer (NCHW) + ( + ["", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))];" + ], + ), + ) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.parametrize_targets("opencl") + def test_simple_texture_to_scalar_fp32( + self, input_info, output_info, find_patterns, dtype, target + ): + simple_texture_to_scalar_common( + target, input_info, output_info, find_patterns, dtype, "float32" + ) + + +def texture_to_scalar_reuse_ssa_common( + target, input_info, output_info, find_patterns, dtype, cast_type +): + def _compute(): + p0 = te.placeholder(input_info[1], name="p0", dtype=dtype) + p0_comp = te.compute(input_info[1], lambda *i: p0(*i), name="p0_comp") + if len(output_info[1]) == 4 and len(input_info[1]) == 5: + out = te.compute( + output_info[1], + lambda n, c, h, w: p0_comp[n][c // 4][h][w][c % 4].astype(cast_type), + name="out", + ) + out2 = te.compute( + output_info[1], + lambda n, c, h, w: out[n][c][h][w] + + p0_comp[n][c // 4][h][w][c % 4].astype(cast_type), + name="out", + ) + elif len(output_info[1]) == 5 and len(input_info[1]) == 5: + out = te.compute( + output_info[1], + lambda n, c, h, w, cb: p0_comp[n][c][h][w][cb].astype(cast_type), + name="out", + ) + out2 = te.compute( + output_info[1], + lambda n, c, h, w, cb: out[n][c][h][w][cb] + + p0_comp[n][c][h][w][cb].astype(cast_type), + name="out", + ) + else: + raise Exception("Impossible case") + out_sum = te.compute(output_info[1], lambda *i: out(*i) + out2(*i), name="out_sum") + dummy_out = te.compute(output_info[1], lambda *i: out_sum(*i), name="dummy_out") + return p0, dummy_out + + def _schedule(dummy_out): + from tvm.topi.adreno.utils import bind_data_copy + + s = te.create_schedule(dummy_out.op) + out_sum = s[dummy_out].op.input_tensors[0] + out, out2 = s[out_sum].op.input_tensors + p0_comp = s[out].op.input_tensors[0] + s[p0_comp].set_scope(input_info[0]) + bind_data_copy(s[p0_comp]) + s[out].set_scope(output_info[0]) + s[out2].set_scope(output_info[0]) + s[out2].compute_inline() + s[out].compute_inline() + s[out_sum].set_scope(output_info[0]) + bind_data_copy(s[out_sum]) + bind_data_copy(s[dummy_out]) + return s + + p0, dummy_out = _compute() + s = _schedule(dummy_out) + + fun = tvm.build(s, [p0, dummy_out], target) + dev = tvm.device(target, 0) + opencl_source = fun.imported_modules[0].get_source() + start_idx = 0 + for pattern in find_patterns: + start_idx = opencl_source.find(pattern, start_idx) + assert start_idx > -1 + + input_np = np.random.uniform(size=[i for i in input_info[1]]).astype(dtype) + input_tvm = tvm.nd.array(input_np, dev) + c = tvm.nd.empty(output_info[1], dtype, dev) + # Doesn't run OpenCL code for FP16 because GPUs in CI don't support FP16 inference + if cast_type == "float32": + fun(input_tvm, c) + # For output len == 5 it makes no sense to check the accuracy + if cast_type == "float32" and len(output_info[1]) == 4: + np_result = input_np * 3 + np_result = np_result.transpose(0, 2, 3, 1, 4) # NCHW4c -> NHWC4c + np_result = np.squeeze(np_result, axis=3) + np_result = np_result.transpose(0, 3, 1, 2) # NHWC -> NCHW + np.testing.assert_allclose(c.asnumpy(), np_result, rtol=1e-2, atol=1e-2) + + +class TestTextureToScalarReuseSSAFP16: + # (input [scope, shape], output [scope, shape], [find_patterns]) + input_info, output_info, find_patterns = tvm.testing.parameters( + # 1. Texture (NCHW4c) -> Cast(FP16) -> Buffer (NCHW) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));", + "out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)]) + (((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)]) + ((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)])));", + ], + ), + # 2. Buffer (NCHW4c) -> Cast(FP16) -> Buffer (NCHW) + ( + ["", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]) + (((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]) + ((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))])));" + ], + ), + # 3. Texture (NCHW4c) -> Cast(FP16) -> Texture (NCHW4c) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["global.texture", (1, 1, 40, 40, 4)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5))));", + "write_imageh(out_sum, (int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5)), ((convert_half4(v_)) + ((convert_half4(v_)) + (convert_half4(v_)))));", + ], + ), + ) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.parametrize_targets("opencl") + def test_texture_to_scalar_reuse_ssa_fp16( + self, input_info, output_info, find_patterns, dtype, target + ): + texture_to_scalar_reuse_ssa_common( + target, input_info, output_info, find_patterns, dtype, "float16" + ) + + +class TestTextureToScalarReuseSSAFP32: + # (input [scope, shape], output [scope, shape], [find_patterns]) + input_info, output_info, find_patterns = tvm.testing.parameters( + # 1. Texture (NCHW4c) -> Buffer (NCHW) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));", + "out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (((float*)&v_)[(((int)get_group_id(0)) >> 1)] + (((float*)&v_)[(((int)get_group_id(0)) >> 1)] + ((float*)&v_)[(((int)get_group_id(0)) >> 1)]));", + ], + ), + # 2. Buffer (NCHW4c) -> Buffer (NCHW) + ( + ["", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))] + (p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))] + p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]));" + ], + ), + ) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.parametrize_targets("opencl") + def test_texture_to_scalar_reuse_ssa_fp32( + self, input_info, output_info, find_patterns, dtype, target + ): + texture_to_scalar_reuse_ssa_common( + target, input_info, output_info, find_patterns, dtype, "float32" + ) + + +class TestLocalArrayToTexture: + # 1. conv2d(Texture(NCHW4c), Texture(OIHW4o)) -> local_array[4] -> Texture (NCHW4c) + input_shape1, input_shape2, output_shape, find_patterns = tvm.testing.parameters( + ( + (1, 1, 40, 40, 4), + (2, 4, 3, 3, 4), + (1, 2, 38, 38, 4), + [ + "float out_local[4];", + "float4 v_ = READ_IMAGEF(p1_comp, image_sampler, ((int2)((((((int)get_group_id(0)) * 14) + ((int)get_local_id(0))) % 38), ((((((int)get_group_id(0)) * 64) + (((int)get_local_id(0)) >> 1)) % 722) / 19))));", + "float4 v__1 = READ_IMAGEF(p2_comp, image_sampler, ((int2)(rw, ((((((((int)get_group_id(0)) * 32) + (((int)get_local_id(0)) >> 2)) / 361) * 12) + (rcb * 3)) + rh))));", + "out_local[cb_c] = (out_local[cb_c] + (((float*)&v_)[rcb] * ((float*)&v__1)[cb_c]));", + "write_imagef(out, (int2)((((((int)get_group_id(0)) * 14) + ((int)get_local_id(0))) % 38), (((((int)get_group_id(0)) * 64) + (((int)get_local_id(0)) >> 1)) / 19)), vload4(0, out_local + 0));", + ], + ), + ) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.parametrize_targets("opencl") + def test_local_array_to_texture( + self, input_shape1, input_shape2, output_shape, find_patterns, dtype, target + ): + def _compute(): + p1 = te.placeholder(input_shape1, name="p1", dtype=dtype) + p1_comp = te.compute(input_shape1, lambda *i: p1(*i), name="p1_comp") + p2 = te.placeholder(input_shape2, name="p2", dtype=dtype) + p2_comp = te.compute(input_shape2, lambda *i: p2(*i), name="p2_comp") + KH, KW = input_shape2[2], input_shape2[3] + IC, ICB = input_shape1[1], input_shape1[4] + rh = te.reduce_axis((0, KH), name="rh") + rw = te.reduce_axis((0, KW), name="rw") + rc = te.reduce_axis((0, IC), name="rc") + rcb = te.reduce_axis((0, ICB), name="rcb") + out = te.compute( + output_shape, + lambda n, c, h, w, cb: te.sum( + (p1_comp[n, rc, h, w, rcb] * p2_comp[c, rc * ICB + rcb, rh, rw, cb]).astype( + dtype + ), + axis=[rh, rw, rc, rcb], + ), + name="out", + ) + dummy_out = te.compute(output_shape, lambda *i: out(*i), name="dummy_out") + return p1, p2, dummy_out + + def _schedule(dummy_out): + from tvm.topi.adreno.utils import bind_data_copy + + s = te.create_schedule(dummy_out.op) + out = s[dummy_out].op.input_tensors[0] + p1_comp, p2_comp = s[out].op.input_tensors + bind_data_copy(s[p1_comp]) + s[p1_comp].set_scope("global.texture") + bind_data_copy(s[p2_comp]) + s[p2_comp].set_scope("global.texture") + OL = s.cache_write(out, "local") + n, c, h, w, cb = s[out].op.axis + fused = s[out].fuse(n, c, h, w) + bx, tx = s[out].split(fused, 128) + s[out].reorder(bx, tx, cb) + s[out].vectorize(cb) + s[out].set_scope("global.texture") + s[out].bind(bx, te.thread_axis("blockIdx.x")) + s[out].bind(tx, te.thread_axis("threadIdx.x")) + s[OL].compute_at(s[out], tx) + bind_data_copy(s[dummy_out]) + return s + + p1, p2, dummy_out = _compute() + s = _schedule(dummy_out) + + fun = tvm.build(s, [p1, p2, dummy_out], target) + dev = tvm.device(target, 0) + opencl_source = fun.imported_modules[0].get_source() + start_idx = 0 + for pattern in find_patterns: + start_idx = opencl_source.find(pattern, start_idx) + assert start_idx > -1 + + input_np1 = np.random.uniform(size=[i for i in input_shape1]).astype(dtype) + input_np2 = np.random.uniform(size=[i for i in input_shape2]).astype(dtype) + input_tvm1 = tvm.nd.array(input_np1, dev) + input_tvm2 = tvm.nd.array(input_np2, dev) + c = tvm.nd.empty(output_shape, dtype, dev) + fun(input_tvm1, input_tvm2, c) + + if __name__ == "__main__": tvm.testing.main()