-
Notifications
You must be signed in to change notification settings - Fork 873
Expand file tree
/
Copy pathquantized_convolution.py
More file actions
346 lines (289 loc) · 12.4 KB
/
quantized_convolution.py
File metadata and controls
346 lines (289 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, List, Optional
import executorch.backends.vulkan.utils as utils
import torch
from executorch.backends.transforms.utils import (
create_constant_placeholder,
get_param_tensor,
)
from executorch.backends.vulkan.patterns.pattern_registry import (
PatternMatch,
register_pattern_detector,
register_pattern_replacement,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.graph_signature import InputKind
class QuantizedConvolutionMatch(PatternMatch):
    """Match of a statically quantized (optionally transposed) 2D convolution.

    The matched subgraph is::

        dequantize(input) -> conv -> [relu] -> quantize(output)

    where the conv weight traces back through a dequantize node to a constant
    placeholder, and the optional bias also traces back to a constant
    placeholder. ``match_found`` is set to True only if every check passes;
    when it is False, the other attributes may be only partially populated and
    must not be relied upon.
    """

    def __init__(self, conv_node: torch.fx.Node) -> None:  # noqa: C901
        """Try to match the quantized convolution pattern anchored at ``conv_node``.

        Args:
            conv_node: The anchor convolution node (``aten.conv2d`` or
                ``aten.convolution`` in the edge dialect).
        """
        self.anchor_node = conv_node
        self.match_found = False
        # Nodes belonging to the matched pattern (anchor + traced arg chains).
        self.all_nodes = [self.anchor_node]

        # Determine if this is a transposed convolution. Only aten.convolution
        # carries the `transposed` (args[6]) and `output_padding` (args[7]) args.
        self.transposed = False
        self.output_padding = [0, 0]
        if conv_node.target == exir_ops.edge.aten.convolution.default:
            transposed_flag = conv_node.args[6] if len(conv_node.args) > 6 else False
            if transposed_flag:
                self.transposed = True
                self.output_padding = (
                    cast(List[int], conv_node.args[7])
                    if len(conv_node.args) > 7
                    else [0, 0]
                )

        # Extract convolution parameters. stride/padding/dilation sit at the same
        # positions for both anchor ops, but `groups` does not:
        #   conv2d(input, weight, bias, stride, padding, dilation, groups)
        #   convolution(input, weight, bias, stride, padding, dilation,
        #               transposed, output_padding, groups)
        self.stride = conv_node.args[3] if len(conv_node.args) > 3 else [1, 1]
        self.padding = conv_node.args[4] if len(conv_node.args) > 4 else [0, 0]
        self.dilation = conv_node.args[5] if len(conv_node.args) > 5 else [1, 1]
        groups_idx = 6 if conv_node.target == exir_ops.edge.aten.conv2d.default else 8
        self.groups = (
            conv_node.args[groups_idx] if len(conv_node.args) > groups_idx else 1
        )

        # Transposed conv only supported with dilation=[1,1]. Normalize to a list
        # so that tuple-valued args compare correctly.
        if self.transposed and list(cast(List[int], self.dilation)) != [1, 1]:
            return

        const_node, arg_chain = utils.trace_args_until_placeholder(
            self.anchor_node.args[1]
        )
        # weight is not a constant tensor - no match
        if const_node is None:
            return

        dequantize_weight_node = None
        # Search for a dequantize node in the arg chain of weight (the last one
        # found in the chain wins).
        for node in arg_chain:
            if isinstance(node, torch.fx.Node) and utils.is_dequant_node(node):
                dequantize_weight_node = node
        # weight is not quantized - no match
        if dequantize_weight_node is None:
            return

        self.weight_node = const_node
        self.dequantize_weight_node = dequantize_weight_node
        self.all_nodes.extend(arg_chain)

        # For transposed conv, verify per-channel quantization is on the OC dimension.
        # Transposed weight shape is (IC, OC_per_group, KH, KW), so per-OC quantization
        # should be on axis=1. If axis=0, that's per-IC which is not supported.
        if self.transposed and utils.is_dequant_per_channel_node(
            self.dequantize_weight_node
        ):
            quant_axis = self.dequantize_weight_node.args[3]
            if quant_axis != 1:
                return

        # Identify weight quantization parameter nodes
        self.weight_scales_node, arg_chain = utils.trace_args_until_placeholder(
            self.dequantize_weight_node.args[1]
        )
        assert self.weight_scales_node is not None
        self.all_nodes.extend(arg_chain)

        self.weight_zeros_node, arg_chain = utils.trace_args_until_placeholder(
            self.dequantize_weight_node.args[2]
        )
        assert self.weight_zeros_node is not None
        self.all_nodes.extend(arg_chain)

        # Identify output node
        self.output_node = self.anchor_node

        # Channel dim of the conv output (assumes an NCHW-style output - TODO
        # confirm for batched/unbatched inputs).
        out_channels = self.output_node.meta["val"].shape[-3]
        # The implementation requires that for non-depthwise grouped convolutions, a
        # group does not cross the texel boundary. The output channels per group must be
        # a multiple of 4, i.e. out_channels must be a multiple of 4 * groups. If this
        # is not true, then don't match the pattern.
        if 1 < self.groups < out_channels and out_channels % (4 * self.groups) != 0:
            return

        # Identify bias node, if applicable
        self.bias_node = None
        if len(self.anchor_node.args) > 2 and self.anchor_node.args[2] is not None:
            self.bias_node, arg_chain = utils.trace_args_until_placeholder(
                self.anchor_node.args[2]
            )
            # A bias that does not come from a constant cannot be folded into the
            # fused op; matching anyway would silently drop it from the result.
            if self.bias_node is None:
                return
            self.all_nodes.extend(arg_chain)

        # Identify input node
        primary_input_node = self.anchor_node.args[0]
        assert isinstance(primary_input_node, torch.fx.Node)
        # Argument must be a dequant node for static quantization
        if not utils.is_dequant_node(primary_input_node):
            return

        self.dequantize_input_node = primary_input_node
        self.quantize_input_node = self.dequantize_input_node.args[0]
        self.input_scales_node = self.dequantize_input_node.args[1]
        self.input_zeros_node = self.dequantize_input_node.args[2]

        self.all_nodes.extend([self.dequantize_input_node])

        # The convolution output must have only one user; it will be either a relu
        # node or a quantize node.
        if len(self.output_node.users) != 1:
            return

        cur_node = next(iter(self.output_node.users))
        self.relu_node = None
        if cur_node.target == exir_ops.edge.aten.relu.default:
            # A relu without users cannot lead to an output quantize node.
            if len(cur_node.users) == 0:
                return
            self.relu_node = cur_node
            cur_node = next(iter(cur_node.users))

        if not utils.is_quant_node(cur_node):
            return

        self.quantize_output_node = cur_node
        self.output_scales_node = self.quantize_output_node.args[1]
        self.output_zeros_node = self.quantize_output_node.args[2]

        self.match_found = True
# Edge-dialect convolution ops that may anchor a "quantized_convolution" match.
# aten.convolution is also the form that can describe a transposed convolution.
convolution_anchor_nodes = {
    exir_ops.edge.aten.convolution.default,
    exir_ops.edge.aten.conv2d.default,
}
@register_pattern_detector("quantized_convolution")
def find_quantized_convolution_patterns(
    node: torch.fx.Node,
) -> Optional[QuantizedConvolutionMatch]:
    """Detector hook for the "quantized_convolution" pattern.

    Returns the populated QuantizedConvolutionMatch when ``node`` is one of the
    supported convolution anchor ops and the full quantized pattern around it
    was matched; returns None otherwise.
    """
    if node.target in convolution_anchor_nodes:
        candidate = QuantizedConvolutionMatch(node)
        if candidate.match_found:
            return candidate
    return None
##
## Pattern Replacement
##
def _compute_weight_sums(
    packed_weight: torch.Tensor, is_depthwise_conv: bool
) -> torch.Tensor:
    """Return per-output-channel sums of a packed quantized weight as int32.

    The sums are needed to apply the activation zero point when accumulating in
    integer arithmetic. The result is zero-padded at the end so that its length
    is a multiple of 4.

    Args:
        packed_weight: The rearranged weight: (H, W, OC) for depthwise convs,
            (OC, H*W*IC_per_group) otherwise.
        is_depthwise_conv: Selects which of the two layouts above is in use.
    """
    if is_depthwise_conv:
        # (H, W, OC): sum over the spatial dims.
        sums = packed_weight.sum(dim=(0, 1))
    else:
        # (OC, H*W*IC_per_group): sum over the flattened kernel dim.
        sums = packed_weight.sum(dim=1)
    sums = sums.to(torch.int32).contiguous()

    # Pad weight sums to align OC to a multiple of 4, matching the alignment
    # applied to weight_scales and bias. Without this, the GPU shader would read
    # out-of-bounds when OC is not a multiple of 4.
    remainder = sums.shape[0] % 4
    if remainder != 0:
        sums = torch.nn.functional.pad(sums, (0, 4 - remainder)).contiguous()
    return sums


@register_pattern_replacement("quantized_convolution")
def make_q8ta_conv2d_custom_op(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: QuantizedConvolutionMatch,
):
    """Replace a matched quantized convolution with a fused et_vk q8ta conv op.

    Steps:
      1. Rearrange the constant quantized weight into the matrix layout
         (OC, H*W*IC_per_group), or (H, W, OC) for depthwise convs, and hand it,
         the weight scales and the bias to
         ``utils.align_width_and_update_state_dict``.
      2. Precompute per-output-channel weight sums and register them as a new
         constant placeholder at the start of the graph.
      3. Insert one of ``q8ta_conv2d_transposed`` / ``q8ta_conv2d_dw`` /
         ``q8ta_conv2d_pw`` / ``q8ta_conv2d`` before the conv node and reroute
         every user of the output quantize node to it.

    Args:
        ep: The exported program owning the constant tensors.
        graph_module: The graph module being rewritten.
        match: A successful QuantizedConvolutionMatch.
    """
    weight_tensor = get_param_tensor(ep, match.weight_node)
    assert weight_tensor is not None

    assert match.weight_scales_node is not None
    weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node)
    assert weight_scales_tensor is not None

    assert match.weight_zeros_node is not None
    weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node)
    assert weight_zeros_tensor is not None

    bias_tensor = None
    if match.bias_node is not None:
        bias_tensor = get_param_tensor(ep, match.bias_node)
        assert bias_tensor is not None

    if match.transposed:
        # Transposed conv weight shape: (IC, OC_per_group, H, W)
        IC, OC_per_group, H, W = weight_tensor.shape
        OC = OC_per_group * match.groups
        IC_per_group = IC // match.groups
        # Reshape to (OC, H*W*IC_per_group) matrix format for Im2Col-based
        # transposed convolution.
        # (IC, OC_per_group, H, W) ->
        # (groups, IC_per_group, OC_per_group, H, W) ->
        # (groups, OC_per_group, H, W, IC_per_group) ->
        # (OC, H*W*IC_per_group)
        weight_tensor = (
            weight_tensor.reshape(match.groups, IC_per_group, OC_per_group, H, W)
            .permute(0, 2, 3, 4, 1)
            .contiguous()
            .reshape(OC, H * W * IC_per_group)
            .contiguous()
        )
    else:
        OC, IC_per_group, H, W = weight_tensor.shape

    is_depthwise_conv = (
        not match.transposed and IC_per_group == 1 and match.groups == OC
    )

    if is_depthwise_conv:
        assert OC % 4 == 0, "depthwise conv requires that OC is divisible by 4"
        # Depthwise convs use a specialized layout; the weight tensor is reshaped
        # from (OC, 1, H, W) to (H, W, OC).
        weight_tensor = (
            weight_tensor.permute(2, 3, 1, 0).contiguous().view(H, W, OC).contiguous()
        )
    elif not match.transposed:
        # Reshape weight tensor from (OC, IC_per_group, H, W) to
        # (OC, H * W * IC_per_group) (i.e. matrix format). This prepares the
        # weights for Im2Col-based convolution.
        weight_tensor = (
            weight_tensor.permute(0, 2, 3, 1)
            .contiguous()
            .view(OC, H * W * IC_per_group)
            .contiguous()
        )

    # Pad to a multiple of 4 so that data loads/stores are aligned with texel
    # boundaries. NOTE(review): presumably the helper pads the last ("width")
    # dim, which is OC for the depthwise layout and the scales/bias, but the
    # flattened kernel dim for the matrix layout - confirm.
    utils.align_width_and_update_state_dict(
        ep, match.weight_node, weight_tensor, force_update=True
    )
    utils.align_width_and_update_state_dict(
        ep, match.weight_scales_node, weight_scales_tensor
    )
    if bias_tensor is not None:
        utils.align_width_and_update_state_dict(ep, match.bias_node, bias_tensor)

    # Only the first node is needed; avoid materializing the whole node list.
    first_graph_node = next(iter(graph_module.graph.nodes))
    with graph_module.graph.inserting_before(first_graph_node):
        qweight_tensor_name = utils.get_tensor_name(ep, match.weight_node)

        # Pre-compute the weight sums which are needed to apply activation zero
        # point when using integer accumulation.
        sum_per_output_channel = _compute_weight_sums(
            weight_tensor, is_depthwise_conv
        )

        # Sanitize the name (replace "." so it is a flat placeholder name).
        sums_name = (qweight_tensor_name + "_sums").replace(".", "_")

        weight_sums_node = create_constant_placeholder(
            exp_program=ep,
            graph=graph_module.graph,
            kind=InputKind.CONSTANT_TENSOR,
            name=sums_name,
            data=sum_per_output_channel,
        )

    is_pointwise_conv = (
        H == 1
        and W == 1
        and list(match.stride) == [1, 1]
        and list(match.dilation) == [1, 1]
        and list(match.padding) == [0, 0]
    )

    with graph_module.graph.inserting_before(match.output_node):
        if match.transposed:
            op_target = exir_ops.edge.et_vk.q8ta_conv2d_transposed.default
        elif is_depthwise_conv:
            op_target = exir_ops.edge.et_vk.q8ta_conv2d_dw.default
        elif is_pointwise_conv:
            op_target = exir_ops.edge.et_vk.q8ta_conv2d_pw.default
        else:
            op_target = exir_ops.edge.et_vk.q8ta_conv2d.default

        op_args = (
            match.quantize_input_node,
            match.input_scales_node,
            match.input_zeros_node,
            match.weight_node,
            weight_sums_node,
            match.weight_scales_node,
            match.output_scales_node,
            match.output_zeros_node,
            match.bias_node,
            [H, W],
            match.stride,
            match.padding,
        )
        # Only the transposed op takes output_padding (between padding and
        # dilation).
        if match.transposed:
            op_args = op_args + (match.output_padding,)
        op_args = op_args + (
            match.dilation,
            match.groups,
            "relu" if match.relu_node is not None else "none",
        )

        qconv_node = graph_module.graph.create_node(
            "call_function",
            op_target,
            args=op_args,
        )

    # NOTE(review): this copies the conv (pre-quantize) output's fake value,
    # whose dtype may differ from the quantized tensor this node replaces -
    # confirm downstream consumers only rely on its shape.
    qconv_node.meta["val"] = match.output_node.meta["val"]
    match.quantize_output_node.replace_all_uses_with(qconv_node)