From d0568812d8640ee6ef57b452ff52d573e9a1cd66 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 9 Jun 2025 15:59:12 -0700 Subject: [PATCH] [ET-VK] New implementation of `cat` operator ## Changes * Introduce `concat_texture.glsl` and `concat_buffer.glsl` to implement the `torch.cat` operator * Introduce `Concat.cpp` to replace `Cat.cpp` * Fix a bug with channels-packed buffer tensors where input data would be copied incorrectly with multiple dims have a stride of 1 ## Motivation > * Introduce `concat_texture.glsl` and `concat_buffer.glsl` to implement the `torch.cat` operator > * Introduce `Concat.cpp` to replace `Cat.cpp` The existing implementation of `torch.cat` uses the copy_channel_offset` shaders. However, these shaders have a critical bug where the output tensor is passed in separately with difference access types, i.e. ``` graph.execute_nodes().emplace_back(new DispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), global_size, local_size, // Inputs and Outputs { {out, vkapi::kWrite}, {out, vkapi::kRead}, {in, vkapi::kRead}, }, ``` This creates many validation layer errors because the memory barriers for the resource cannot be formed properly. The shader essentially relies on undefined behaviour to work correctly. The result is that the `cat` operator produces incorrect result on many platforms. Rather than fix the `copy_offset` shaders, I decided to just introduce new shaders to perform the concat operation. The new implementation handles both buffer and texture inputs and is agnostic to memory layout. Differential Revision: [D76305343](https://our.internmc.facebook.com/intern/diff/D76305343/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/concat_buffer.glsl | 67 +++++++++ .../runtime/graph/ops/glsl/concat_buffer.yaml | 14 ++ .../graph/ops/glsl/concat_texture.glsl | 129 ++++++++++++++++++ .../graph/ops/glsl/concat_texture.yaml | 14 ++ .../runtime/graph/ops/glsl/indexing_utils.h | 23 ---- .../graph/ops/glsl/nchw_to_buffer.glsl | 4 +- .../vulkan/runtime/graph/ops/impl/Cat.cpp | 98 ------------- .../vulkan/runtime/graph/ops/impl/Concat.cpp | 128 +++++++++++++++++ .../vulkan/runtime/graph/ops/impl/Staging.cpp | 5 +- backends/vulkan/test/op_tests/cases.py | 5 +- 10 files changed, 362 insertions(+), 125 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml delete mode 100644 backends/vulkan/runtime/graph/ops/impl/Cat.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/Concat.cpp diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl new file mode 100644 index 00000000000..dabfb79be37 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} + +$for i in range(NUM_INPUTS): + ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "buffer")} + +${layout_declare_ubo(B, "int", "concat_dim")} + +${layout_declare_ubo(B, "ivec4", "out_sizes")} +${layout_declare_ubo(B, "ivec4", "out_strides")} + +$for i in range(NUM_INPUTS): + ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_sizes")} + ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_strides")} + +${layout_declare_ubo(B, "int", "out_numel")} + +${layout_declare_spec_const(C, "int", "out_packed_dim", "DEFAULT_LAYOUT")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const int out_bufi = ivec3(gl_GlobalInvocationID).x; + if (out_bufi >= out_numel) { + return; + } + + // Convert buffer linear index to 4-D tensor index for output + const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim); + + // Determine which input tensor to read from + ivec4 in_tidx = out_tidx; + + $for i in range(NUM_INPUTS): + // Check if the index at the concat dim is within bounds of the input tensor + // If so, read from that input tensor and write to output + if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) { + int in_bufi = tidx_to_bufi(in_tidx, in${i+1}_strides); + t_out[out_bufi] = t_in${i+1}[in_bufi]; + return; + } + // otherwise, decrement the index at the concat dim + else { + in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim]; + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml new file mode 100644 index 00000000000..39f96df5e90 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml @@ -0,0 +1,14 @@ +concat_buffer: + parameter_names_with_default_values: + DTYPE: float + NUM_INPUTS: 2 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: concat_1_buffer + NUM_INPUTS: 1 + - NAME: concat_2_buffer + - NAME: concat_3_buffer + NUM_INPUTS: 3 diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl new file mode 100644 index 00000000000..dac6266bf67 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl @@ -0,0 +1,129 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} + +#define USING_TEXTURE3D + +layout(std430) buffer; + +#include "indexing_utils.h" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} + +$for i in range(NUM_INPUTS): + ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "texture3d")} + +${layout_declare_ubo(B, "int", "concat_dim")} + +$in_metadata = "" +$for i in range(NUM_INPUTS): + $in_metadata += "ivec4 in" + str(i + 1) + "_sizes;\n" + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ${in_metadata} +}; + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); +const lowp int out_packed_dim = unhash_packed_dim(out_layout); + +$for i in range(NUM_INPUTS): + ${layout_declare_spec_const(C, "int", "in" + str(i+1) + "_layout", "DEFAULT_LAYOUT")} + const lowp ivec4 in${i+1}_axis_map = unhash_axis_map(in${i+1}_layout); + const lowp int in${i+1}_packed_dim = unhash_packed_dim(in${i+1}_layout); + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +// Check if we can use the fast path (no texel merging required) +bool can_use_fast_path() { + // Fast path is possible when: + // 1. The concat dimension is not the packed dimension, or + // 2. The concat dimension is the packed dimension but both input tensors have dimensions + // that are multiples of 4 along the packed dimension + if (concat_dim != out_packed_dim) { + return true; + } + + // Check if all input tensors have dimensions that are multiples of 4 along the packed dimension + bool all_concat_dim_size_multiple_of_4 = true; + $for i in range(NUM_INPUTS): + all_concat_dim_size_multiple_of_4 = + all_concat_dim_size_multiple_of_4 && + (in${i+1}_sizes[concat_dim] % 4 == 0); + + return all_concat_dim_size_multiple_of_4; +} + +void main() { + const ivec3 lpos = ivec3(gl_GlobalInvocationID); + ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim); + + if (any(greaterThanEqual(out_tidx, out_sizes))) { + return; + } + + if (can_use_fast_path()) { + // Fast path: No texel merging required + ivec4 in_tidx = out_tidx; + + $for i in range(NUM_INPUTS): + // For each input tensor, check if the tensor index is within bounds. If + // so, read the texel from the input tensor and write it to the output + if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) { + const ivec3 in_pos = tidx_to_pos(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim); + const VEC4_T in_texel = load_texel(t_in${i+1}, in_pos); + write_texel_lpos(t_out, lpos, in_texel, out_axis_map); + return; + } + // Otherwise, adjust the index along the concat dimension and try the next + // input tensor. + else { + in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim]; + } + } + else { + // Slow path: Texel merging required + VEC4_T out_texel = VEC4_T(0); + + // Process each element in the output texel individually + for (int texel_i = 0; texel_i < 4; ++texel_i) { + ivec4 curr_out_tidx = out_tidx; + curr_out_tidx[out_packed_dim] += texel_i; + + // Skip if we're out of bounds + if (curr_out_tidx[out_packed_dim] >= out_sizes[out_packed_dim]) { + continue; + } + + ivec4 in_tidx = curr_out_tidx; + $for i in range(NUM_INPUTS): + // For each input tensor, check if the tensor index is within bounds. If + // so, read the corresponding texel element from the input tensor and + // write it to the output texel. + if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) { + const ivec4 in_posi = tidx_to_posi(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim); + out_texel[texel_i] = load_texel(t_in${i+1}, in_posi.xyz)[in_posi.w]; + continue; + } + // Otherwise, adjust the index along the concat dimension and try the + // next input tensor. + else { + in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim]; + } + } + + write_texel_lpos(t_out, lpos, out_texel, out_axis_map); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml new file mode 100644 index 00000000000..ed5003382a1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml @@ -0,0 +1,14 @@ +concat_texture: + parameter_names_with_default_values: + DTYPE: float + NUM_INPUTS: 2 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: concat_1_texture3d + NUM_INPUTS: 1 + - NAME: concat_2_texture3d + - NAME: concat_3_texture3d + NUM_INPUTS: 3 diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 2b41d2b7e1a..b74caf11848 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -68,21 +68,6 @@ */ #define mod4(x) ((x) & 3) -/* - * Find the packed dimension of a tensor given its strides. The packed dimension - * is the "fastest moving" dimension which will have a stride of 1. - */ -int find_packed_dim(const ivec4 strides) { - int packed_dim = 0; - for (int i = 0; i <= 3; i++) { - if (strides[i] == 1) { - packed_dim = i; - break; - } - } - return packed_dim; -} - /* * Get the staging buffer indices that contain the data of the texel that * corresponds to the provided tensor index. Since the texel have 4 elements, @@ -144,14 +129,6 @@ ivec4 bufi_to_tidx(int bufi, const ivec4 strides, const int packed_dim) { return idx; } -// Convenience overload of the above function, which will determine the packed -// dim from the strides automatically so it doesn't have to be passed in as a -// function argument. -ivec4 bufi_to_tidx(const int bufi, const ivec4 strides) { - int packed_dim = find_packed_dim(strides); - return bufi_to_tidx(bufi, strides, packed_dim); -} - int tidx_to_bufi(const ivec4 tidx, ivec4 strides) { return tidx.x * strides.x + tidx.y * strides.y + tidx.z * strides.z + tidx.w * strides.w; diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl index ba4e4dd9dd9..e5fbc42c27b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl @@ -28,7 +28,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; // This constant is unused in this shader but is kept so that the signature is // consistent with nchw_to_image. -${layout_declare_spec_const(C, "int", "UNUSED_layout", "0")} +${layout_declare_spec_const(C, "int", "packed_dim", "0")} ${layout_declare_spec_const(C, "int", "transpose_hw", "0")} void main() { @@ -37,7 +37,7 @@ void main() { return; } - ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides); + ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, packed_dim); ivec4 sizes = out_sizes; if (transpose_hw == 1) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Cat.cpp b/backends/vulkan/runtime/graph/ops/impl/Cat.cpp deleted file mode 100644 index 25a0ff9a7f5..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Cat.cpp +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include - -#include -#include -#include -#include -#include - -namespace vkcompute { - -void add_cat_default_node( - ComputeGraph& graph, - ValueRef in_list_ref, - ValueRef dim_ref, - ValueRef out) { - ValueListPtr input_list = graph.get_value_list(in_list_ref); - int64_t dim = graph.extract_scalar(dim_ref); - vTensorPtr t_out = graph.get_tensor(out); - - const auto packed_dim = t_out->packed_dim(); - const auto packed_dim_index = static_cast(kWidth4D - packed_dim); - - DimIndex dim_index = normalize_to_dim_index(*t_out, dim); - // Index of dimension to be concatenated in (w, h, c * b) coordinate system - const auto dim_xyz_index = std::min(2, -dim_index - 1); - - if (dim_index > kWidth4D || dim_index < kBatch4D) { - VK_THROW("Unexpected value of dim_index=", dim_index); - } - - utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false); - utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false); - - const bool is_concat_channel = (dim_index == kChannel4D); - - // if concatenating channels - if (is_concat_channel) { - // set destination offset w as channel size of the output tensor - dst_offset[3] = dim_at(t_out->sizes(), kChannel4D); - } - - for (ValueRef input_ref : *input_list) { - const vTensorPtr t_in = graph.get_tensor(input_ref); - const utils::ivec3 range = t_in->logical_limits(); - const auto in_channel_size = dim_at(t_in->sizes(), kChannel4D); - // if concatenating same dimension as the packed dimension - if (dim_index == packed_dim_index) { - // if concatenating channels, use add_copy_channel_offset_node function as - // add_copy_packed_dim_offset_node does not support channel packing - if (is_concat_channel) { - add_copy_channel_offset_node( - graph, - input_ref, - in_channel_size, - src_offset[2], - dst_offset[2], - out); - dst_offset[dim_xyz_index] += in_channel_size; - } else { - // src_offset[3] is not used now but will be used in the future when - // add_copy_packed_dim_offset_node will support channel packing - // - // set source offset w as channel size of the output tensor if - // concatenating channels - src_offset[3] = is_concat_channel ? in_channel_size : 0; - add_copy_packed_dim_offset_node( - graph, input_ref, range, src_offset, dst_offset, out); - dst_offset[dim_xyz_index] += dim_at(t_in->sizes(), packed_dim_index); - } - } else { - // set source offset w as channel size of the output tensor if - // concatenating channels - src_offset[3] = is_concat_channel ? in_channel_size : 0; - add_copy_offset_node( - graph, input_ref, range, src_offset, dst_offset, out, true, false); - dst_offset[dim_xyz_index] += - is_concat_channel ? in_channel_size : range[dim_xyz_index]; - } - } -} - -void cat_default(ComputeGraph& graph, const std::vector& args) { - add_cat_default_node(graph, args[0], args[1], args[2]); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP(aten.cat.default, cat_default); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Concat.cpp b/backends/vulkan/runtime/graph/ops/impl/Concat.cpp new file mode 100644 index 00000000000..6335aa24808 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Concat.cpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +namespace vkcompute { + +void add_concat_node( + ComputeGraph& graph, + const ValueRef tensors_ref, + const ValueRef dim_ref, + const ValueRef out) { + std::vector in_value_refs; + + { + const ValueListPtr tensors = graph.get_value_list(tensors_ref); + + VK_CHECK_COND( + tensors->size() <= 3, + "Currently only concatenation of <= 3 tensors is supported"); + + for (const ValueRef in : *tensors) { + in_value_refs.push_back(in); + } + } + + const int64_t dim = graph.extract_scalar(dim_ref); + + const int64_t ndim = graph.dim_of(in_value_refs.at(0)); + int64_t normalized_dim = dim; + if (normalized_dim < 0) { + normalized_dim += ndim; + } + + const int64_t dim_whcn = nchw_dim_to_whcn_dim(normalized_dim, ndim); + const ValueRef dim_whcn_ref = graph.get_or_add_value_for_int(dim_whcn); + + vkapi::ParamsBindList param_buffers = { + graph.get_or_create_int_param_buffer(dim_whcn_ref, 0)}; + + std::vector push_constants; + vkapi::SpecVarList spec_vars; + + if (graph.is_buffer_storage(out)) { + param_buffers.append(graph.sizes_ubo(out)); + param_buffers.append(graph.strides_ubo(out)); + + for (const ValueRef in_ref : in_value_refs) { + param_buffers.append(graph.sizes_ubo(in_ref)); + param_buffers.append(graph.strides_ubo(in_ref)); + } + + param_buffers.append(graph.numel_ubo(out)); + + spec_vars = {graph.packed_dim_of(out)}; + } else { + push_constants = {graph.sizes_pc_of(out)}; + + spec_vars = {graph.hashed_layout_of(out)}; + + for (const ValueRef in_ref : in_value_refs) { + push_constants.push_back(graph.sizes_pc_of(in_ref)); + spec_vars.append(graph.hashed_layout_of(in_ref)); + } + } + + std::string kernel_name = "concat"; + if (in_value_refs.size() == 1) { + kernel_name += "_1"; + } else if (in_value_refs.size() == 2) { + kernel_name += "_2"; + } else if (in_value_refs.size() == 3) { + kernel_name += "_3"; + } + if (graph.is_buffer_storage(out)) { + kernel_name += "_buffer"; + } else { + kernel_name += "_texture3d"; + } + + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in_value_refs, vkapi::kRead}}, + // Parameter buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + nullptr)); +} + +void cat_tensor(ComputeGraph& graph, const std::vector& args) { + // Extract arguments + const ValueRef tensors_ref = args.at(0); + const ValueRef dim_ref = args.at(1); + const ValueRef out = args.at(2); + + // Add concat node + add_concat_node(graph, tensors_ref, dim_ref, out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.cat.default, cat_tensor); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index f429ab0fc25..c998a8da66f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -30,13 +30,16 @@ void add_staging_to_tensor_node( *graph.get_tensor(out_tensor), graph.int8_buffers_enabled()); std::vector pcs; + vkapi::SpecVarList spec_vars; if (graph.is_buffer_storage(out_tensor)) { pcs = { graph.sizes_pc_of(out_tensor), graph.strides_pc_of(out_tensor), graph.numel_pc_of(out_tensor)}; + spec_vars = {graph.packed_dim_of(out_tensor)}; } else { pcs = {graph.sizes_pc_of(out_tensor)}; + spec_vars = {graph.hashed_layout_of(out_tensor)}; } graph.execute_nodes().emplace_back(new DynamicDispatchNode( @@ -51,7 +54,7 @@ void add_staging_to_tensor_node( // Push Constants pcs, // Specialization Constants - {graph.hashed_layout_of(out_tensor)}, + {spec_vars}, // Resize Args {}, // Resizing Logic diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 277daa60451..b6200a1ac7e 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1106,9 +1106,12 @@ def get_cat_inputs(): ) test_suite.layouts = [ "utils::kWidthPacked", - "utils::kHeightPacked", "utils::kChannelsPacked", ] + test_suite.storage_types = [ + "utils::kTexture3D", + "utils::kBuffer", + ] test_suite.data_gen = "make_seq_tensor" test_suite.dtypes = ["at::kFloat"] return test_suite