diff --git a/python/tvm/topi/adreno/utils.py b/python/tvm/topi/adreno/utils.py
index de0505af03d4..1a1cc747faac 100644
--- a/python/tvm/topi/adreno/utils.py
+++ b/python/tvm/topi/adreno/utils.py
@@ -525,28 +525,27 @@ def bind_data_copy(stage, axis_to_vectorize=None):
             stage.bind(block, te.thread_axis("blockIdx.z"))
             stage.bind(thread, te.thread_axis("threadIdx.z"))
     else:
-        axes = stage.op.axis
-        fused = stage.fuse(*axes[:-1])
-        if shape[-1] <= 32:
+        if shape[-1] == 4:
+            axes = stage.op.axis
+            fused = stage.fuse(*axes[:-1])
             ftc = numpy.prod(shape[:-1])
             div = get_div(ftc, 64)
             block, thread = stage.split(fused, factor=div)
             stage.bind(block, te.thread_axis("blockIdx.x"))
             stage.bind(thread, te.thread_axis("threadIdx.x"))
-            if shape[-1] == 4:
-                stage.vectorize(axes[-1])
-        # 1024 is the maximum work group size for Adreno devices.
-        # See: CL_DEVICE_MAX_WORK_GROUP_SIZE
-        elif shape[-1] > 1024:
-            ftc = numpy.prod(shape[:-1])
-            div = get_div(ftc, 1024)
-            by, ty = stage.split(axes[-1], factor=div)
-            stage.bind(fused, te.thread_axis("blockIdx.x"))
-            stage.bind(by, te.thread_axis("blockIdx.y"))
-            stage.bind(ty, te.thread_axis("threadIdx.y"))
+            stage.vectorize(axes[-1])
         else:
-            stage.bind(fused, te.thread_axis("blockIdx.x"))
-            stage.bind(*axes[-1:], te.thread_axis("threadIdx.x"))
+            ftc = numpy.prod(shape)
+            vthread = get_div(ftc, 8)
+            fused = stage.fuse(*stage.op.axis)
+            ftc = ftc / vthread
+            # 1024 is the maximum work group size on most Adreno GPUs
+            num_thread = get_div(ftc, 1024 // vthread)
+            a, b = stage.split(fused, factor=num_thread)
+            a, c = stage.split(a, factor=vthread)
+            stage.bind(c, te.thread_axis("vthread"))
+            stage.bind(a, te.thread_axis("blockIdx.x"))
+            stage.bind(b, te.thread_axis("threadIdx.x"))
 
 
 def get_texture_storage(shape):
diff --git a/src/relay/transforms/annotate_texture_storage.cc b/src/relay/transforms/annotate_texture_storage.cc
index 6904c6b5d7cc..277c5e1da424 100644
--- a/src/relay/transforms/annotate_texture_storage.cc
+++ b/src/relay/transforms/annotate_texture_storage.cc
@@ -206,7 +206,9 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
       }
     }
 
-    primitive_supports_texture_ = SupportsTextureStorage(call);
+    if (!primitive_supports_texture_) {
+      primitive_supports_texture_ = SupportsTextureStorage(call);
+    }
 
     for (auto& arg : call->args) {
       Visit(arg);
@@ -362,6 +364,12 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
 
   bool SupportsTextureStorage(const CallNode* call) const {
     bool supports_texture_storage = false;
+    // we need to verify only entry functions since one of the entry ops defines the main schedule
+    for (const auto& arg : call->args) {
+      if (!arg.as<VarNode>()) {
+        return false;
+      }
+    }
     if (auto attrs = call->attrs.as<Conv2DAttrs>()) {
       if (attrs->data_layout == "NCHW4c" && attrs->kernel_layout == "OIHW4o") {
         supports_texture_storage = true;
diff --git a/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py b/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
index c73e411a700e..5198cbdf6bc6 100644
--- a/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
+++ b/tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
@@ -1074,3 +1074,196 @@ def test_conv2d_winograd_non_rect(target, dtype):
     )
     matches = re.findall("winograd", graph)
     assert len(matches) > 0
+
+
+# functions are repeated; params scopes differ in the reused functions
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_injective_nwo_inputs1(target, dtype):
+    """
+    Use case for verifying the stability of annotation of primary
+    functions that have several ops accepting data from outside the primary function.
+    The order in which ops are visited while traversing the graph inside a primary
+    function can depend on the order of relay graph creation, so the annotation
+    mechanism must be robust to the graph traversal order.
+    Currently, whether a prim function supports textures is decided by whether
+    *any* op accepting an input of the function supports textures.
+
+              Input
+             /     \
+    layout_transform (NCHW->NCHW4c)   |
+             |                       /
+         conv2d (1)                 /
+             |                     /
+         conv2d (2)     mean      /
+          /      \       /   <- Primary function with several head ops
+    (1)add   (2)layout_transform   |
+       |       (NCHW4c->NCHW)      |
+       |            |           \  /
+       |            |          (3) add
+       |            |              |
+    layout_transform    \          /
+     (NCHW4c->NCHW)      \        /
+          \               mul
+           \             /
+                add
+
+    This test verifies the case when the last op visited is (3), which does not
+    support textures, but (1) does support them; thus the whole function must
+    support textures.
+    """
+    input_shape = (1, 4, 40, 40)
+    filter_shape1 = (4, 4, 3, 3)
+    filter_shape2 = (4, 4, 3, 3)
+    filter_shape3 = (4, 4, 3, 3)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype)
+    W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype)
+    mean = relay.mean(A, axis=1, keepdims=True)
+    conv1 = relay.nn.conv2d(
+        A,
+        W1,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[1, 1, 1, 1],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=4,
+        kernel_size=(3, 3),
+    )
+
+    conv2 = relay.nn.conv2d(
+        conv1,
+        W2,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[1, 1, 1, 1],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=4,
+        kernel_size=(3, 3),
+    )
+
+    ad3 = relay.op.add(conv1, conv2)
+    ad1 = relay.op.add(mean, conv1)
+    ad2 = relay.op.multiply(ad1, conv2)
+    ad4 = relay.op.add(ad3, ad2)
+
+    mod = relay.Function([A, W1, W2], ad4)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data1 = np.zeros(filter_shape1).astype(dtype)
+    filter_data2 = np.zeros(filter_shape2).astype(dtype)
+    initializer("weight", filter_data1)
+    initializer("weight", filter_data2)
+    params1 = {
+        "weight1": tvm.nd.array(filter_data1),
+        "weight2": tvm.nd.array(filter_data2),
+    }
+
+    static_memory_scope = [
+        "global",
+        "global.texture",
+        "global.texture-nhwc",
+        "global.texture",
+        "global.texture-nhwc",
+        "global.texture",
+        "global",
+        "global",
+    ]
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope)
+
+
+# functions are repeated; params scopes differ in the reused functions
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_injective_nwo_inputs2(target, dtype):
+    """
+    Use case for verifying the stability of annotation of primary functions
+    that have several ops accepting data from outside the primary function.
+    The order in which ops are visited while traversing the graph inside a primary
+    function can depend on the order of relay graph creation.
+    Thus the annotation mechanism must be robust to the graph traversal order.
+    Currently, whether a prim function supports textures is decided by whether
+    *any* op accepting an input of the function supports textures.
+
+              Input
+             /     \
+    layout_transform (NCHW->NCHW4c)   |
+             |                       /
+         conv2d (1)                 /
+             |                     /
+         conv2d (2)     mean      /
+          /      \       /   <- Primary function with several head ops
+    (1)add   (2)layout_transform   |
+       |       (NCHW4c->NCHW)      |
+       |            |           \  /
+       |            |          (3) add
+       |            |              |
+    layout_transform    \          /
+     (NCHW4c->NCHW)      \        /
+          \               mul
+           \             /
+                add
+
+    This test verifies the case when the last op visited is (1); it supports
+    textures, and the whole prim function is therefore considered a function
+    working with textures.
+    """
+    input_shape = (1, 4, 40, 40)
+    filter_shape1 = (4, 4, 3, 3)
+    filter_shape2 = (4, 4, 3, 3)
+    filter_shape3 = (4, 4, 3, 3)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    W1 = relay.var("weight1", shape=filter_shape1, dtype=dtype)
+    W2 = relay.var("weight2", shape=filter_shape2, dtype=dtype)
+    mean = relay.mean(A, axis=1, keepdims=True)
+    conv1 = relay.nn.conv2d(
+        A,
+        W1,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[1, 1, 1, 1],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=4,
+        kernel_size=(3, 3),
+    )
+
+    conv2 = relay.nn.conv2d(
+        conv1,
+        W2,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[1, 1, 1, 1],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=4,
+        kernel_size=(3, 3),
+    )
+
+    ad3 = relay.op.add(conv1, conv2)
+    ad1 = relay.op.add(mean, conv1)
+    ad2 = relay.op.multiply(ad1, conv2)
+    ad4 = relay.op.add(ad2, ad3)
+
+    mod = relay.Function([A, W1, W2], ad4)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data1 = np.zeros(filter_shape1).astype(dtype)
+    filter_data2 = np.zeros(filter_shape2).astype(dtype)
+    initializer("weight", filter_data1)
+    initializer("weight", filter_data2)
+    params1 = {
+        "weight1": tvm.nd.array(filter_data1),
+        "weight2": tvm.nd.array(filter_data2),
+    }
+
+    static_memory_scope = [
+        "global",
+        "global.texture",
+        "global.texture-nhwc",
+        "global.texture",
+        "global",
+        "global.texture-nhwc",
+        "global.texture",
+        "global",
+    ]
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, static_memory_scope)
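
For reference, the sketch below replays in plain Python the split arithmetic of the new
fallback branch in bind_data_copy, for one of the shapes exercised by the tests above.
It is a minimal illustration, not TVM code: get_div is re-implemented here under the
assumption that it returns the largest divisor of its first argument not exceeding the
second, and the two stage.split calls are modeled with integer division.

import numpy

def get_div(value, start):
    # largest divisor of `value` that is <= `start` (assumed behavior)
    for d in range(start, 0, -1):
        if value % d == 0:
            return d
    return 1

# NCHW output of the final layout_transform in the tests above; the innermost
# extent is 40, not 4, so the vectorized branch does not apply.
shape = (1, 4, 40, 40)

fused = int(numpy.prod(shape))                      # 6400 elements to copy
vthread = get_div(fused, 8)                         # 8 virtual threads
per_vthread = fused // vthread                      # 800 elements per virtual thread
num_thread = get_div(per_vthread, 1024 // vthread)  # 100, so num_thread * vthread <= 1024

a, b = fused // num_thread, num_thread              # models stage.split(fused, factor=num_thread)
a, c = a // vthread, vthread                        # models stage.split(a, factor=vthread)

assert a * b * c == int(numpy.prod(shape))
print(f"blockIdx.x={a}, threadIdx.x={b}, vthread={c}")  # 8, 100, 8

Budgeting num_thread against 1024 // vthread keeps the combined num_thread * vthread
extent within the 1024 work group limit mentioned in the comment, while the vthread
binding lets each physical thread handle several strided elements of the copy.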