
Commit 93b38be

morgolock authored and SS-JIA committed
Vulkan Q8 Conv2D: specialize shader on static parameters and tensor sizes
This change moves all fixed Conv2D parameters (kernel shape, stride, padding, dilation, groups) and the input/output tensor dimensions into Vulkan specialization constants. By making these values compile-time constants, the backend can generate more optimized pipelines, eliminate generic fallback paths, and reduce dynamic indexing overhead. This significantly improves performance across large and compute-intensive convolution workloads.

Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
Change-Id: I3efe3de80dece91341ae4111bef1254c6779a1db
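For context on the mechanism: each `layout_declare_spec_const` entry in the shader templates below corresponds to a SPIR-V specialization constant with a fixed `constant_id`, and the host supplies concrete values when the compute pipeline is created. The sketch below is a minimal raw-Vulkan illustration of that hand-off, not the ExecuTorch Vulkan backend's actual helper code; the constant-ID ordering and the parameter values are assumptions made for the example.

```cpp
// Minimal sketch (raw Vulkan, not the backend's helper API): bind integer
// specialization constants so the shader sees them as compile-time values.
// constant_id 0..2 are assumed to carry the workgroup size (local_size_*_id
// in the shaders below); the remaining IDs follow declaration order here and
// are illustrative only.
#include <vulkan/vulkan.h>

#include <cstdint>
#include <vector>

// Build a VkSpecializationInfo whose i-th map entry binds constant_id i to the
// i-th int32 in `values`. The returned struct points into `entries`/`values`,
// so both must outlive it.
VkSpecializationInfo make_spec_info(
    const std::vector<int32_t>& values,
    std::vector<VkSpecializationMapEntry>& entries) {
  entries.clear();
  for (uint32_t i = 0; i < values.size(); ++i) {
    // constantID, byte offset into pData, size of the constant in bytes.
    entries.push_back({i, i * uint32_t(sizeof(int32_t)), sizeof(int32_t)});
  }
  VkSpecializationInfo info{};
  info.mapEntryCount = static_cast<uint32_t>(entries.size());
  info.pMapEntries = entries.data();
  info.dataSize = values.size() * sizeof(int32_t);
  info.pData = values.data();
  // Assign to VkPipelineShaderStageCreateInfo::pSpecializationInfo when
  // creating the compute pipeline; the driver compiles the shader with these
  // values folded in as constants.
  return info;
}

int main() {
  // Hypothetical values for one conv2d dispatch; the real backend derives
  // these from the graph's tensors and the op's attributes.
  std::vector<int32_t> values = {
      /*local_size_x=*/64, /*local_size_y=*/1, /*local_size_z=*/1,
      /*apply_bias=*/1,
      /*stride_x=*/1, /*stride_y=*/1,
      /*padding_x=*/0, /*padding_y=*/0,
      // ... remaining conv2d params and input/output sizes
  };
  std::vector<VkSpecializationMapEntry> entries;
  VkSpecializationInfo info = make_spec_info(values, entries);
  return info.mapEntryCount == values.size() ? 0 : 1;
}
```

Because the constants are resolved when the pipeline is compiled, loop bounds such as `conv2d_params_K4_per_group` can be folded or unrolled by the shader compiler, and the UBO reads that previously fetched `conv2d_params`, `output_sizes`, and `input_sizes` are no longer needed, which is what the per-shader changes below implement.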
1 parent 42e3222 commit 93b38be

18 files changed: 510 additions & 124 deletions

backends/vulkan/runtime/graph/ops/glsl/col2im.glsl

Lines changed: 33 additions & 3 deletions
@@ -35,13 +35,40 @@ ${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_arr
 ${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)}

 // Sizes of the convolution output image
-${layout_declare_ubo(B, "ivec4", "output_sizes")}
+//${layout_declare_ubo(B, "ivec4", "output_sizes")}
 // Sizes of the convolution input image
-${layout_declare_ubo(B, "ivec4", "input_sizes")}
+//${layout_declare_ubo(B, "ivec4", "input_sizes")}
 // Sizes of the im2col matrix of the convolution output
 ${layout_declare_ubo(B, "ivec4", "matrix_sizes")}

-${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
+//${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
+
+${layout_declare_spec_const(C, "int", "apply_bias", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_stride_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_stride_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_padding_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_padding_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_dilation_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_dilation_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_kernel_size_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_kernel_size_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_in_channels_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_out_channels_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K4", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_logical_K", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_logical_K_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_groups", "1")}
+
+${layout_declare_spec_const(C, "int", "output_x", "1")}
+${layout_declare_spec_const(C, "int", "output_y", "1")}
+${layout_declare_spec_const(C, "int", "output_z", "1")}
+${layout_declare_spec_const(C, "int", "output_w", "1")}
+${layout_declare_spec_const(C, "int", "input_x", "1")}
+${layout_declare_spec_const(C, "int", "input_y", "1")}
+${layout_declare_spec_const(C, "int", "input_z", "1")}
+${layout_declare_spec_const(C, "int", "input_w", "1")}

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

@@ -79,6 +106,9 @@ void main() {
   const int n4 = int(gl_GlobalInvocationID.x);
   const int m4 = int(gl_GlobalInvocationID.y);

+  const ivec4 output_sizes = ivec4(int(output_x), int(output_y), int(output_z), int(output_w));
+  const ivec4 input_sizes = ivec4(int(input_x), int(input_y), int(input_z), int(input_w));
+
   const int n = mul_4(n4);
   const int m = mul_4(m4);

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8_utils.glslh

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ void perform_conv1d(
     const WeightRow weight_row) {
   for (int out_w = 0; out_w < 4; ++out_w) {
     [[unroll]] for (int kx = 0; kx < weight_row.len; ++kx) {
-      const int in_w = out_w * conv2d_params.stride.x;
+      const int in_w = out_w * conv2d_params_stride_x;
       out_block.data[out_w] = fma(
           input_window.data[in_w + kx],
           weight_row.data[kx],

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_q8ta_q8csw_q8to.glsl

Lines changed: 43 additions & 12 deletions
@@ -34,9 +34,9 @@ ${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_arra
 ${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}

-${layout_declare_ubo(B, "ivec4", "output_sizes")}
-${layout_declare_ubo(B, "ivec4", "input_sizes")}
-${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
+//${layout_declare_ubo(B, "ivec4", "output_sizes")}
+//${layout_declare_ubo(B, "ivec4", "input_sizes")}
+//${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}

 layout(push_constant) uniform restrict Block {
   float input_scale;
@@ -48,11 +48,42 @@ layout(push_constant) uniform restrict Block {
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 ${layout_declare_spec_const(C, "int", "apply_bias", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_stride_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_stride_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_padding_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_padding_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_dilation_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_dilation_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_kernel_size_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_kernel_size_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_in_channels_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_out_channels_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K4", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_logical_K", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_logical_K_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_groups", "1")}
+
+${layout_declare_spec_const(C, "int", "output_x", "1")}
+${layout_declare_spec_const(C, "int", "output_y", "1")}
+${layout_declare_spec_const(C, "int", "output_z", "1")}
+${layout_declare_spec_const(C, "int", "output_w", "1")}
+${layout_declare_spec_const(C, "int", "input_x", "1")}
+${layout_declare_spec_const(C, "int", "input_y", "1")}
+${layout_declare_spec_const(C, "int", "input_z", "1")}
+${layout_declare_spec_const(C, "int", "input_w", "1")}
+

 #include "conv2d_dw_q8_utils.glslh"

 void main() {
   const int tid = int(gl_GlobalInvocationID.x);
+
+  const ivec4 output_sizes = ivec4(int(output_x), int(output_y), int(output_z), int(output_w));
+  const ivec4 input_sizes = ivec4(int(input_x), int(input_y), int(input_z), int(input_w));
+
+
   Conv2dBlockExtents out_block_extents = make_block_extents(output_sizes);

   Conv2dBlockIndex out_block_idx = linear_idx_to_block_idx(
@@ -64,23 +95,23 @@ void main() {

   const int out_w = mul_4(out_block_idx.data.x);
   const int w_start =
-      (out_w * conv2d_params.stride.x) - conv2d_params.padding.x;
-  const int w_end = ((out_w + 3) * conv2d_params.stride.x) -
-      conv2d_params.padding.x +
-      (conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x;
+      (out_w * conv2d_params_stride_x) - conv2d_params_padding_x;
+  const int w_end = ((out_w + 3) * conv2d_params_stride_x) -
+      conv2d_params_padding_x +
+      (conv2d_params_kernel_size_x - 1) * conv2d_params_dilation_x;

   Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes);

   const ivec4 input_zps = ivec4(pack_into_int32(ivec4(input_zp)));
   const vec4 weight_scales = vec4(t_weight_scales[out_block_idx.data.z]);

-  const int Kw4 = div_up_4(conv2d_params.kernel_size.x);
+  const int Kw4 = div_up_4(conv2d_params_kernel_size_x);

   FPOutBlock out_block;
-  for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
+  for (int ky = 0; ky < conv2d_params_kernel_size_y; ky++) {
     const int out_h = out_block_idx.data.y;
-    const int h = out_h * conv2d_params.stride.y - conv2d_params.padding.y +
-        ky * conv2d_params.dilation.y;
+    const int h = out_h * conv2d_params_stride_y - conv2d_params_padding_y +
+        ky * conv2d_params_dilation_y;

     InputWindow1D input_window = load_input_window(
         w_start,
@@ -96,7 +127,7 @@ void main() {
         out_block_idx.data.z,
         ky,
         out_block_extents.data.z,
-        conv2d_params.kernel_size.x,
+        conv2d_params_kernel_size_x,
         Kw4,
         weight_scales);

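As a sanity check of the window-bounds arithmetic in the hunk above, here is a small worked example with assumed parameter values (stride 2, padding 1, 3-wide kernel, dilation 1, output block starting at out_w = 4); the numbers are illustrative and not taken from the commit.

```cpp
// Worked example of the w_start / w_end computation above, with assumed
// parameters: stride_x = 2, padding_x = 1, kernel_size_x = 3, dilation_x = 1.
#include <cstdio>

int main() {
  const int stride_x = 2, padding_x = 1, kernel_size_x = 3, dilation_x = 1;
  const int out_w = 4;  // first of the 4 output columns handled by this block

  const int w_start = (out_w * stride_x) - padding_x;  // 7
  const int w_end = ((out_w + 3) * stride_x) - padding_x +
      (kernel_size_x - 1) * dilation_x;                // 15
  // The block therefore needs input columns 7..15 inclusive.
  std::printf("w_start=%d w_end=%d\n", w_start, w_end);
  return 0;
}
```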

backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh

Lines changed: 8 additions & 8 deletions
@@ -63,27 +63,27 @@ void im2col_idx_to_input_tidx(
   TensorIndex4D output_tidx;
   unwrap_m(output_tidx, im2col_idx.row);

-  const int in_channels_per_group = conv2d_params.in_channels_per_group;
+  const int in_channels_per_group = conv2d_params_in_channels_per_group;
   // Determine the corresponding position within the convolution window based
   // on the col index (more specifically, the col index within the group)
   const int channel_within_group =
       im2col_idx.col_idx_in_group % in_channels_per_group;
   const int kernel_x = (im2col_idx.col_idx_in_group / in_channels_per_group) %
-      conv2d_params.kernel_size.x;
+      conv2d_params_kernel_size_x;
   const int kernel_y = im2col_idx.col_idx_in_group /
-      (in_channels_per_group * conv2d_params.kernel_size.x);
+      (in_channels_per_group * conv2d_params_kernel_size_x);

   // Calculate the actual input channel index
   const int channel_idx =
-      im2col_idx.group_idx * conv2d_params.in_channels_per_group +
+      im2col_idx.group_idx * conv2d_params_in_channels_per_group +
       channel_within_group;

   // Calculate corresponding input coordinates based on output position
   // associated with the row index.
-  const int input_y = int(output_tidx.data.y * conv2d_params.stride.y) -
-      int(conv2d_params.padding.y) + int(kernel_y * conv2d_params.dilation.y);
-  const int input_x = int(output_tidx.data.x * conv2d_params.stride.x) -
-      int(conv2d_params.padding.x) + int(kernel_x * conv2d_params.dilation.x);
+  const int input_y = int(output_tidx.data.y * conv2d_params_stride_y) -
+      int(conv2d_params_padding_y) + int(kernel_y * conv2d_params_dilation_y);
+  const int input_x = int(output_tidx.data.x * conv2d_params_stride_x) -
+      int(conv2d_params_padding_x) + int(kernel_x * conv2d_params_dilation_x);

   input_tidx.data = ivec4(input_x, input_y, channel_idx, output_tidx.data.w);
 }
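To make the column-index decomposition above concrete, the following worked example uses assumed values (8 input channels per group, a 3x3 kernel, stride 1, padding 1, dilation 1); it mirrors the formulas in `im2col_idx_to_input_tidx` but is not part of the commit.

```cpp
// Worked example of the im2col column decomposition above, with assumed
// parameters: in_channels_per_group = 8, 3x3 kernel, stride 1, padding 1,
// dilation 1.
#include <cstdio>

int main() {
  const int in_channels_per_group = 8;
  const int kernel_size_x = 3;
  const int stride_x = 1, stride_y = 1;
  const int padding_x = 1, padding_y = 1;
  const int dilation_x = 1, dilation_y = 1;

  const int col_idx_in_group = 20;       // column index within the group
  const int output_x = 5, output_y = 2;  // output position from the row index

  const int channel_within_group = col_idx_in_group % in_channels_per_group;        // 4
  const int kernel_x = (col_idx_in_group / in_channels_per_group) % kernel_size_x;  // 2
  const int kernel_y = col_idx_in_group / (in_channels_per_group * kernel_size_x);  // 0

  const int input_y = output_y * stride_y - padding_y + kernel_y * dilation_y;      // 1
  const int input_x = output_x * stride_x - padding_x + kernel_x * dilation_x;      // 6
  std::printf("c=%d kx=%d ky=%d in=(%d,%d)\n",
              channel_within_group, kernel_x, kernel_y, input_x, input_y);
  return 0;
}
```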

backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh

Lines changed: 6 additions & 6 deletions
@@ -64,8 +64,8 @@ void load_im2col_block_fast(
   // Due to the assumption that in_channels_per_group % 4 == 0, it is
   // guaranteed that the next 4 columns (including this one) is part of the
   // same group.
-  im2col_idx.group_idx = im2col_idx.col / conv2d_params.K_per_group;
-  im2col_idx.col_idx_in_group = im2col_idx.col % conv2d_params.K_per_group;
+  im2col_idx.group_idx = im2col_idx.col / conv2d_params_K_per_group;
+  im2col_idx.col_idx_in_group = im2col_idx.col % conv2d_params_K_per_group;

   [[unroll]] for (int m_off = 0; m_off < 4; ++m_off) {
     if (im2col_idx.row >= M) {
@@ -98,9 +98,9 @@ void load_im2col_block_slow(
   im2col_idx_base.col = mul_4(k4);
   im2col_idx_base.row = mul_4(m4);

-  im2col_idx_base.group_idx = im2col_idx_base.col / conv2d_params.K_per_group;
+  im2col_idx_base.group_idx = im2col_idx_base.col / conv2d_params_K_per_group;
   im2col_idx_base.col_idx_in_group =
-      im2col_idx_base.col % conv2d_params.K_per_group;
+      im2col_idx_base.col % conv2d_params_K_per_group;

   [[unroll]] for (int m_off = 0; m_off < 4; ++m_off) {
     [[unroll]] for (int k_off = 0; k_off < 4; ++k_off) {
@@ -109,7 +109,7 @@ void load_im2col_block_slow(
       im2col_idx.col_idx_in_group += k_off;

       // bounds checking
-      if (im2col_idx.col_idx_in_group >= conv2d_params.logical_K_per_group ||
+      if (im2col_idx.col_idx_in_group >= conv2d_params_logical_K_per_group ||
           im2col_idx.row >= M) {
         block.data[m_off][k_off] = T(0);
         continue;
@@ -129,7 +129,7 @@ void load_im2col_block(
     const int m4,
     const int logical_K,
     const int M) {
-  if (mod_4(conv2d_params.in_channels_per_group) == 0) {
+  if (mod_4(conv2d_params_in_channels_per_group) == 0) {
     load_im2col_block_fast(block, k4, m4, logical_K, M);
   } else {
     load_im2col_block_slow(block, k4, m4, logical_K, M);

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_q8ta_q8csw_q8to_tiled.glsl

Lines changed: 33 additions & 4 deletions
@@ -42,9 +42,9 @@ ${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_arra
 ${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}

-${layout_declare_ubo(B, "ivec4", "output_sizes")}
-${layout_declare_ubo(B, "ivec4", "input_sizes")}
-${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
+//${layout_declare_ubo(B, "ivec4", "output_sizes")}
+//${layout_declare_ubo(B, "ivec4", "input_sizes")}
+//${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}

 layout(push_constant) uniform restrict Block {
   float input_scale;
@@ -56,6 +56,32 @@ layout(push_constant) uniform restrict Block {
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 ${layout_declare_spec_const(C, "int", "apply_bias", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_stride_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_stride_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_padding_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_padding_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_dilation_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_dilation_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_kernel_size_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_kernel_size_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_in_channels_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_out_channels_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K4", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_logical_K", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_logical_K_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_groups", "1")}
+
+${layout_declare_spec_const(C, "int", "output_x", "1")}
+${layout_declare_spec_const(C, "int", "output_y", "1")}
+${layout_declare_spec_const(C, "int", "output_z", "1")}
+${layout_declare_spec_const(C, "int", "output_w", "1")}
+${layout_declare_spec_const(C, "int", "input_x", "1")}
+${layout_declare_spec_const(C, "int", "input_y", "1")}
+${layout_declare_spec_const(C, "int", "input_z", "1")}
+${layout_declare_spec_const(C, "int", "input_w", "1")}
+

 #include "conv2d_int8_input_tile_load.glslh"
 #include "linear_int8_weight_tile_load.glslh"
@@ -72,6 +98,9 @@ void main() {
   output_block_idx.data.x = int(gl_GlobalInvocationID.y) * TILE_M4;
   output_block_idx.data.y = int(gl_GlobalInvocationID.z);

+  const ivec4 output_sizes = ivec4(int(output_x), int(output_y), int(output_z), int(output_w));
+  const ivec4 input_sizes = ivec4(int(input_x), int(input_y), int(input_z), int(input_w));
+
   Conv2dBlockExtents output_block_extents = make_block_extents(output_sizes);
   if (block_idx_out_of_bounds(output_block_idx, output_block_extents)) {
     return;
@@ -88,7 +117,7 @@ void main() {
   Int8InputTileIndex input_idx = make_initial_int8_input_tile_index(
       output_block_idx, input_block_extents);

-  for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) {
+  for (int k4 = 0; k4 < conv2d_params_K4_per_group; k4++) {
     load_packed_int8_input_tile(int8_input_tile, input_idx);

     load_int8_weight_tile(

backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ void perform_conv1d(
     const ivec4 weight_block,
     const int kx) {
   [[unroll]] for (int out_w = 0; out_w < 4; ++out_w) {
-    const int window_i = out_w * conv2d_params.stride.x + kx;
+    const int window_i = out_w * conv2d_params_stride_x + kx;
     [[unroll]] for (int out_c = 0; out_c < 4; ++out_c) {
       accum.data[out_w][0][out_c] = dotPacked4x8AccSatEXT(
           input_window.data[window_i],

backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl

Lines changed: 39 additions & 8 deletions
@@ -39,13 +39,40 @@ ${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, i
 ${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)}
 ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)}

-${layout_declare_ubo(B, "ivec4", "output_sizes")}
-${layout_declare_ubo(B, "ivec4", "input_sizes")}
-${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}
+//${layout_declare_ubo(B, "ivec4", "output_sizes")}
+//${layout_declare_ubo(B, "ivec4", "input_sizes")}
+//${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")}

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 ${layout_declare_spec_const(C, "int", "apply_bias", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_stride_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_stride_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_padding_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_padding_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_dilation_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_dilation_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_kernel_size_x", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_kernel_size_y", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_in_channels_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_out_channels_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K4_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K4", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_K_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_logical_K", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_logical_K_per_group", "1")}
+${layout_declare_spec_const(C, "int", "conv2d_params_groups", "1")}
+
+${layout_declare_spec_const(C, "int", "output_x", "1")}
+${layout_declare_spec_const(C, "int", "output_y", "1")}
+${layout_declare_spec_const(C, "int", "output_z", "1")}
+${layout_declare_spec_const(C, "int", "output_w", "1")}
+${layout_declare_spec_const(C, "int", "input_x", "1")}
+${layout_declare_spec_const(C, "int", "input_y", "1")}
+${layout_declare_spec_const(C, "int", "input_z", "1")}
+${layout_declare_spec_const(C, "int", "input_w", "1")}
+
+

 #include "linear_fp_input_tile_load.glslh"
 #include "linear_int8_weight_tile_load.glslh"
@@ -60,6 +87,10 @@ void main() {
   const int out_tile_x = int(gl_GlobalInvocationID.x);
   const int out_tile_y = int(gl_GlobalInvocationID.y);

+  const ivec4 output_sizes = ivec4(int(output_x), int(output_y), int(output_z), int(output_w));
+  const ivec4 input_sizes = ivec4(int(input_x), int(input_y), int(input_z), int(input_w));
+
+
   const int n = int(out_tile_x * TILE_N);
   const int m = int(out_tile_y * TILE_M);

@@ -75,10 +106,10 @@ void main() {
     return;
   }

-  const int group_idx = n / conv2d_params.out_channels_per_group;
-  const int input_k4_offset = conv2d_params.K4_per_group * group_idx;
+  const int group_idx = n / conv2d_params_out_channels_per_group;
+  const int input_k4_offset = conv2d_params_K4_per_group * group_idx;

-  const int K4 = conv2d_params.K4;
+  const int K4 = conv2d_params_K4;
   const int N4 = div_up_4(N);

   FPOutTile out_tile;
@@ -90,13 +121,13 @@ void main() {
   const bool dont_check_bounds = (M - m) >= TILE_M;

   if (dont_check_bounds) {
-    for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) {
+    for (int k4 = 0; k4 < conv2d_params_K4_per_group; k4++) {
       load_input_tile_no_checks(in_tile, k4 + input_k4_offset, m, K4, M);
       load_int8_weight_tile(int8_weight_tile, n4, k4, N4);
       fp_accumulate_with_int8_weight(out_tile, in_tile, int8_weight_tile);
     }
   } else {
-    for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) {
+    for (int k4 = 0; k4 < conv2d_params_K4_per_group; k4++) {
       load_input_tile_with_checks(in_tile, k4 + input_k4_offset, m, K4, M);
       load_int8_weight_tile(int8_weight_tile, n4, k4, N4);
       fp_accumulate_with_int8_weight(out_tile, in_tile, int8_weight_tile);
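For the grouped-convolution indexing in the last hunk, a small worked example with assumed values (16 output channels per group, K4_per_group = 18) shows how each tile is steered to its group's slice of the K dimension; the numbers are illustrative, not from the commit.

```cpp
// Worked example of the grouped-convolution offsets above, with assumed
// parameters: out_channels_per_group = 16, K4_per_group = 18.
#include <cstdio>

int main() {
  const int out_channels_per_group = 16;
  const int K4_per_group = 18;

  const int n = 24;  // output channel handled by this tile (falls in group 1)

  const int group_idx = n / out_channels_per_group;      // 1
  const int input_k4_offset = K4_per_group * group_idx;  // 18
  // The k4 loop then reads input columns [18, 36) instead of the whole K4
  // range, so each group only touches its own slice of the im2col matrix.
  std::printf("group=%d k4_offset=%d\n", group_idx, input_k4_offset);
  return 0;
}
```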

0 commit comments
