Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 1 addition & 44 deletions backends/qualcomm/_passes/annotate_quant_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,17 @@
from typing import Any, Dict

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.utils.constants import (
QCOM_AXIS,
QCOM_BLOCK_SIZE,
QCOM_DTYPE,
QCOM_ENCODING,
QCOM_QUANT_ATTRS,
QCOM_QUANT_MAX,
QCOM_QUANT_MIN,
QCOM_REQUANTIZE,
QCOM_SCALE,
QCOM_SCALES,
QCOM_ZERO_POINT,
QCOM_ZERO_POINTS,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import dq_ops, get_quant_attrs, q_ops
Expand Down Expand Up @@ -101,43 +96,9 @@ def _annotate_requant(self, n):
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs

def _dequant_fold_params(self, n, quant_attrs, param):
    """Dequantize a folded (quantized) parameter back to fp32 and store it.

    If an operation is not supported by QNN and falls back, it expects an
    fp32 parameter, so the quantized parameter feeding the dq node ``n`` is
    dequantized here and written back to the program.

    Args:
        n: Dequantize node; ``n.args[0]`` is the parameter node to rewrite.
        quant_attrs: Quantization attributes; ``QCOM_ENCODING`` selects the
            dequantization scheme (per-channel, affine, or the default
            scale / zero-point path).
        param: The quantized parameter tensor.
    """
    encoding = quant_attrs[QCOM_ENCODING]
    if encoding in (
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
    ):
        dim, axis = param.dim(), quant_attrs[QCOM_AXIS]
        scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis)
        offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis)
        # Cast to fp32 *before* subtracting the zero points: subtracting on a
        # narrow integer tensor (e.g. uint8) can wrap around.
        param = (
            param.to(torch.float32)
            .sub(offsets)
            .mul(scales)
            .to(torch.float32)
            .contiguous()
        )
    elif encoding in (exir_ops.edge.pt2e_quant.dequantize_affine.default,):
        param = torch.ops.pt2e_quant.dequantize_affine(
            param,
            block_size=quant_attrs[QCOM_BLOCK_SIZE],
            scale=quant_attrs[QCOM_SCALE],
            zero_point=quant_attrs[QCOM_ZERO_POINT],
            input_dtype=quant_attrs[QCOM_DTYPE],
            quant_min=quant_attrs[QCOM_QUANT_MIN],
            quant_max=quant_attrs[QCOM_QUANT_MAX],
            output_dtype=torch.float32,
        )
    else:
        scale = quant_attrs[QCOM_SCALE]
        offset = quant_attrs[QCOM_ZERO_POINT]
        # Same reasoning as above: do the arithmetic in fp32 so that
        # ``param - offset`` cannot overflow the quantized integer dtype.
        param = (
            param.to(torch.float32)
            .sub(offset)
            .mul(scale)
            .to(torch.float32)
            .contiguous()
        )

    # Persist the fp32 tensor and keep the node's recorded value in sync.
    set_parameter(param, n.args[0], self.edge_program)
    n.args[0].meta["val"] = param

def _annotate_quant_attrs(
self, graph_module: torch.fx.GraphModule
) -> torch.fx.GraphModule:
# Keep track of const params that have already been dequantized, so a param
# with more than one user is not dequantized multiple times.
visited_const_param = set()
for n in graph_module.graph.nodes:
self._annotate_requant(n)
# With fold_quant enabled, check if the input of dq op is quantized param.
Expand All @@ -149,10 +110,6 @@ def _annotate_quant_attrs(
quant_attrs = get_quant_attrs(self.edge_program, n)
self._annotate_source_nodes(n, quant_attrs)

if param is not None and n.args[0] not in visited_const_param:
visited_const_param.add(n.args[0])
self._dequant_fold_params(n, quant_attrs, param)

return graph_module

def call(self, graph_module: torch.fx.GraphModule):
Expand Down
125 changes: 94 additions & 31 deletions backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta
Expand All @@ -23,16 +21,43 @@ class ConvertConv1dToConv2d(ExportPass):
def __init__(self, edge_program: torch.export.ExportedProgram):
    """Store the exported program and build the 1d -> 2d conv op mapping.

    Args:
        edge_program: The exported program kept on the instance as
            ``self.edge_program``.
    """
    super().__init__()
    self.edge_program = edge_program

    aten = torch.ops.aten
    # Each 1d convolution op is keyed to the 2d op that replaces it.
    self.conv_op_map = {
        aten.conv1d.default: aten.conv2d.default,
        aten.conv_transpose1d.default: aten.conv_transpose2d.input,
    }

def append_qdq(
    self,
    graph_module: torch.fx.GraphModule,
    node: torch.fx.Node,
    qdq_node: torch.fx.Node,
):
    """Insert a quantize/dequantize pair right after ``node``.

    The pair reuses the arguments of ``qdq_node`` (everything after its
    first argument). If ``qdq_node`` is neither a per-tensor quantize nor a
    per-tensor dequantize op, nothing is inserted.

    Args:
        graph_module: Graph module whose graph receives the new nodes.
        node: Node after which the q/dq pair is inserted.
        qdq_node: Node whose target and trailing arguments are mirrored.

    Returns:
        The newly created dequantize node, or ``node`` itself when no pair
        was inserted.
    """
    quantize = torch.ops.quantized_decomposed.quantize_per_tensor.default
    dequantize = torch.ops.quantized_decomposed.dequantize_per_tensor.default
    if qdq_node.target not in (quantize, dequantize):
        return node

    graph = graph_module.graph
    shared_args = qdq_node.args[1:]

    with graph.inserting_after(node):
        quantize_args = (node, *shared_args)
        quantized = graph.create_node("call_function", quantize, quantize_args)
    quantized.meta = copy_meta(node.meta)
    # Convert the recorded value using the last quantize argument
    # (presumably the target dtype -- confirm against the q op schema).
    quantized.meta["val"] = quantized.meta["val"].to(quantize_args[-1])

    with graph.inserting_after(quantized):
        dequantized = graph.create_node(
            "call_function", dequantize, (quantized, *shared_args)
        )
    dequantized.meta = copy_meta(node.meta)

    return dequantized

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
conv_op = exir_ops.edge.aten.convolution.default
for node in graph.nodes:
if node.target == conv_op and node.meta["val"].dim() == 3:

if node.target in self.conv_op_map:
input_node = node.args[0]
with graph_module.graph.inserting_after(input_node):
unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default
unsqueeze_op = torch.ops.aten.unsqueeze_copy.default
unsqueeze_node = graph.create_node(
"call_function",
unsqueeze_op,
Expand All @@ -44,52 +69,88 @@ def call(self, graph_module: torch.fx.GraphModule):
unsqueeze_node.meta = copy_meta(
input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
)
qdq_node_after_unsqueeze = self.append_qdq(
graph_module=graph_module,
node=unsqueeze_node,
qdq_node=input_node,
)

with graph_module.graph.inserting_after(unsqueeze_node):

filter_node = node.args[1]
with graph_module.graph.inserting_after(qdq_node_after_unsqueeze):
filter_arg = node.args[1]
filter_node = (
filter_arg
if filter_arg.op == "placeholder"
else node.args[1].args[0]
)
filter_node.meta["val"] = (
filter_node.meta["val"].unsqueeze(2).contiguous()
)
filter_tensor = get_parameter(filter_node, self.edge_program)
# Ensure tensor is nn.Parameter type, so program does not fail during edge_program._validate()
filter_tensor = nn.Parameter(filter_tensor.unsqueeze(2))
set_parameter(filter_tensor, filter_node, self.edge_program)
filter_tensor = get_parameter(
filter_node, self.edge_program
).unsqueeze(2)
set_parameter(
(
torch.nn.Parameter(filter_tensor)
if filter_tensor.dtype == torch.float
else filter_tensor
),
filter_node,
self.edge_program,
)

num_args = len(node.args)
bias_node = node.args[2]
stride = [1] + node.args[3]
padding = [0] + node.args[4]
dilation = [1] + node.args[5]
transpose = node.args[6]
output_padding = [0] + node.args[7]
groups = node.args[8]

conv2d_node = graph.create_node(
"call_function",
conv_op,
(
unsqueeze_node,
filter_node,
stride = [1] + node.args[3] if num_args > 3 else [1, 1]
padding = [0] + node.args[4] if num_args > 4 else [0, 0]
if node.target == torch.ops.aten.conv1d.default:
dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
groups = node.args[6] if num_args > 5 else 1
conv_args = (
qdq_node_after_unsqueeze,
node.args[1],
bias_node,
stride,
padding,
dilation,
transpose,
groups,
)
else:
output_padding = (
[0] + node.args[5] if num_args > 5 else [0, 0]
)
groups = node.args[6] if num_args > 6 else 1
dilation = [1] + node.args[7] if num_args > 7 else [1, 1]
conv_args = (
qdq_node_after_unsqueeze,
node.args[1],
bias_node,
stride,
padding,
output_padding,
groups,
),
dilation,
)
conv2d_node = graph.create_node(
"call_function",
self.conv_op_map[node.target],
conv_args,
)
conv2d_node.meta = copy_meta(
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
)
qdq_node_after_conv2d = self.append_qdq(
graph_module=graph_module,
node=conv2d_node,
qdq_node=list(node.users)[0],
)

with graph_module.graph.inserting_after(conv2d_node):
squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
with graph_module.graph.inserting_after(qdq_node_after_conv2d):
squeeze_op = torch.ops.aten.squeeze_copy.dims
squeeze_node = graph.create_node(
"call_function",
squeeze_op,
(
conv2d_node,
qdq_node_after_conv2d,
[2],
),
)
Expand All @@ -102,8 +163,10 @@ def call(self, graph_module: torch.fx.GraphModule):
QCOM_REQUANTIZE
]
conv2d_node.meta.pop(QCOM_REQUANTIZE, None)

for user in node.users.copy():
user.replace_input_with(node, squeeze_node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass

from .utils import dq_ops


class ExpandBroadcastTensorShape(ExportPass):
"""
Expand Down Expand Up @@ -45,9 +47,13 @@ def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule):
exir_ops.edge.aten.view_copy.default,
(arg, tuple(new_rank)),
)
# Try to skip dq_ops to get the correct param node, if applicable.
arg_meta = (
arg.args[0].meta if arg.target in dq_ops else arg.meta
)
# Meta needs to be copied entry by entry so that the fake tensor is
# updated correctly without affecting the meta of arg.
for k, v in arg.meta.items():
for k, v in arg_meta.items():
reshape_node.meta[k] = v
reshape_node.meta["val"] = reshape_node.meta["val"].reshape(
new_rank
Expand Down
30 changes: 24 additions & 6 deletions backends/qualcomm/_passes/fold_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.builders.utils import is_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass
Expand All @@ -16,23 +18,38 @@ class FoldQDQ(ExportPass):
Erase QDQ pattern.
"""

def __init__(self):
def __init__(self, edge_program: torch.export.ExportedProgram, force_fold=False):
super(FoldQDQ, self).__init__()
self.edge_program = edge_program
self.force_fold = force_fold

def _fold(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
def _annotate_bypass(self, node):
    """Mark ``node`` and its upstream ``call_function`` nodes as bypassed.

    Sets ``meta[QCOM_BYPASS_NODE] = True`` on ``node`` and, transitively, on
    every direct positional argument that is an fx node with op
    ``"call_function"``.
    """
    # An explicit worklist replaces recursion; the resulting flags are the same.
    pending = [node]
    while pending:
        current = pending.pop()
        current.meta[QCOM_BYPASS_NODE] = True
        pending.extend(
            arg
            for arg in current.args
            if isinstance(arg, torch.fx.Node) and arg.op == "call_function"
        )

def _fold_dq(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
# remove dq
for n in graph_module.graph.nodes:
user_list = list(n.users.keys())
if n.target not in dq_ops:
continue
for user_n in user_list:
user_n.replace_input_with(n, n.args[0])
graph_module.graph.erase_node(n)

# skip parameters & buffers
if not self.force_fold and is_parameter(n.args[0], self.edge_program):
self._annotate_bypass(n)
else:
for user_n in user_list:
user_n.replace_input_with(n, n.args[0])
graph_module.graph.erase_node(n)

def _fold_q(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
# remove q
for n in graph_module.graph.nodes:
if n.target not in q_ops:
continue

to_be_removed = [n]
source_n = n.args[0]

Expand All @@ -57,7 +74,8 @@ def _fold(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
graph_module.graph.erase_node(n)

def call(self, graph_module: torch.fx.GraphModule):
self._fold(graph_module)
self._fold_dq(graph_module)
self._fold_q(graph_module)
graph_module.recompile()
dead_code_elimination_pass(graph_module)
return PassResult(graph_module, True)
4 changes: 4 additions & 0 deletions backends/qualcomm/_passes/qnn_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,17 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram):
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
self.add_pass(DecomposeExpM1())
# This pass rewrites the state_dict, so it must run before
# to_edge_transform_and_lower.
self.add_pass(ConvertConv1dToConv2d(exported_program))
self.add_pass(ConvertSquareToPow())
self.add_pass(LiftConstantScalarOperands())
self._transform(exported_program.graph_module)
ep = lift_constant_tensor_pass(exported_program)
return ep

def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
self.add_pass(FoldQDQ(exported_program, force_fold=True))
self.add_pass(InsertRequantize())
self.add_pass(InsertIOQDQ(exported_program))
self.add_pass(LayoutTransform(exported_program, insert_permute=True))
Expand Down
7 changes: 4 additions & 3 deletions backends/qualcomm/_passes/replace_index_put_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def call(self, graph_module: torch.fx.GraphModule):
copy_node := list(node.users)[0]
) and copy_node.target == exir_ops.edge.aten.copy.default:
m_buffer_node = copy_node.args[0]
bad_frozen_node = node.args[0]
dq_node = node.args[0]
bad_frozen_node = dq_node.args[0]
if QCOM_QUANT_ATTRS in bad_frozen_node.meta:
m_buffer_node.meta[QCOM_QUANT_ATTRS] = bad_frozen_node.meta[
QCOM_QUANT_ATTRS
Expand All @@ -43,8 +44,8 @@ def call(self, graph_module: torch.fx.GraphModule):
m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]
]
)
with graph.inserting_after(bad_frozen_node):
node.replace_input_with(bad_frozen_node, m_buffer_node)
with graph.inserting_after(dq_node):
node.replace_input_with(dq_node, m_buffer_node)
else:
continue

Expand Down
1 change: 0 additions & 1 deletion backends/qualcomm/_passes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def get_passes_dependency_for_capture_program():
AnnotateStack: [RemoveRedundancy],
AnnotateUnbind: [RemoveRedundancy],
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
ConvertConv1dToConv2d: [FoldQDQ],
ConvertUpsampleBicubicWithBilinear: [RemoveRedundancy],
DecomposeAny: [RemoveRedundancy],
DecomposeLinalgVectorNorm: [RemoveRedundancy],
Expand Down
Loading
Loading