diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py index b4f14fc28cd..d9ef9cb691d 100644 --- a/backends/qualcomm/_passes/annotate_quant_attrs.py +++ b/backends/qualcomm/_passes/annotate_quant_attrs.py @@ -7,10 +7,8 @@ from typing import Any, Dict import torch -from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter +from executorch.backends.qualcomm.builders.utils import get_parameter from executorch.backends.qualcomm.utils.constants import ( - QCOM_AXIS, - QCOM_BLOCK_SIZE, QCOM_DTYPE, QCOM_ENCODING, QCOM_QUANT_ATTRS, @@ -18,11 +16,8 @@ QCOM_QUANT_MIN, QCOM_REQUANTIZE, QCOM_SCALE, - QCOM_SCALES, QCOM_ZERO_POINT, - QCOM_ZERO_POINTS, ) -from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from .utils import dq_ops, get_quant_attrs, q_ops @@ -101,43 +96,9 @@ def _annotate_requant(self, n): n.args[0].meta.setdefault(QCOM_REQUANTIZE, {}) n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs - # Dequant all the fold_quant parameters back to fp32. - # If an operation is not supported by QNN and got fallback, it will expect a fp32 param. - def _dequant_fold_params(self, n, quant_attrs, param): - if quant_attrs[QCOM_ENCODING] in [ - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default - ]: - dim, axis = param.dim(), quant_attrs[QCOM_AXIS] - scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis) - offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis) - param = param.sub(offsets).mul(scales).to(torch.float32).contiguous() - elif quant_attrs[QCOM_ENCODING] in [ - exir_ops.edge.pt2e_quant.dequantize_affine.default - ]: - param = torch.ops.pt2e_quant.dequantize_affine( - param, - block_size=quant_attrs[QCOM_BLOCK_SIZE], - scale=quant_attrs[QCOM_SCALE], - zero_point=quant_attrs[QCOM_ZERO_POINT], - input_dtype=quant_attrs[QCOM_DTYPE], - quant_min=quant_attrs[QCOM_QUANT_MIN], - quant_max=quant_attrs[QCOM_QUANT_MAX], - output_dtype=torch.float32, - ) - else: - scale = quant_attrs[QCOM_SCALE] - offset = quant_attrs[QCOM_ZERO_POINT] - param = param.sub(offset).mul(scale).to(torch.float32).contiguous() - - set_parameter(param, n.args[0], self.edge_program) - n.args[0].meta["val"] = param - def _annotate_quant_attrs( self, graph_module: torch.fx.GraphModule ) -> torch.fx.GraphModule: - # Keep track of const params that has been dequant, so it does not get - # dequant multiple times if the const param has more than 1 user - visited_const_param = set() for n in graph_module.graph.nodes: self._annotate_requant(n) # With fold_quant enabled, check if the input of dq op is quantized param. @@ -149,10 +110,6 @@ def _annotate_quant_attrs( quant_attrs = get_quant_attrs(self.edge_program, n) self._annotate_source_nodes(n, quant_attrs) - if param is not None and n.args[0] not in visited_const_param: - visited_const_param.add(n.args[0]) - self._dequant_fold_params(n, quant_attrs, param) - return graph_module def call(self, graph_module: torch.fx.GraphModule): diff --git a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py index 72dc29c2880..dabe0243a47 100644 --- a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py +++ b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py @@ -5,10 +5,8 @@ # LICENSE file in the root directory of this source tree. import torch -import torch.nn as nn from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE -from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from .utils import copy_meta @@ -23,16 +21,43 @@ class ConvertConv1dToConv2d(ExportPass): def __init__(self, edge_program: torch.export.ExportedProgram): super(ConvertConv1dToConv2d, self).__init__() self.edge_program = edge_program + self.conv_op_map = { + torch.ops.aten.conv1d.default: torch.ops.aten.conv2d.default, + torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input, + } + + def append_qdq( + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + qdq_node: torch.fx.Node, + ): + q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + if qdq_node.target not in {q_op, dq_op}: + return node + + with graph_module.graph.inserting_after(node): + q_args = (node, *qdq_node.args[1:]) + q_node = graph_module.graph.create_node("call_function", q_op, q_args) + q_node.meta = copy_meta(node.meta) + q_node.meta["val"] = q_node.meta["val"].to(q_args[-1]) + with graph_module.graph.inserting_after(q_node): + dq_args = (q_node, *qdq_node.args[1:]) + dq_node = graph_module.graph.create_node( + "call_function", dq_op, dq_args + ) + dq_node.meta = copy_meta(node.meta) + + return dq_node def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph - conv_op = exir_ops.edge.aten.convolution.default for node in graph.nodes: - if node.target == conv_op and node.meta["val"].dim() == 3: - + if node.target in self.conv_op_map: input_node = node.args[0] with graph_module.graph.inserting_after(input_node): - unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default + unsqueeze_op = torch.ops.aten.unsqueeze_copy.default unsqueeze_node = graph.create_node( "call_function", unsqueeze_op, @@ -44,52 +69,88 @@ def call(self, graph_module: torch.fx.GraphModule): unsqueeze_node.meta = copy_meta( input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)} ) + qdq_node_after_unsqueeze = self.append_qdq( + graph_module=graph_module, + node=unsqueeze_node, + qdq_node=input_node, + ) - with graph_module.graph.inserting_after(unsqueeze_node): - - filter_node = node.args[1] + with graph_module.graph.inserting_after(qdq_node_after_unsqueeze): + filter_arg = node.args[1] + filter_node = ( + filter_arg + if filter_arg.op == "placeholder" + else node.args[1].args[0] + ) filter_node.meta["val"] = ( filter_node.meta["val"].unsqueeze(2).contiguous() ) - filter_tensor = get_parameter(filter_node, self.edge_program) - # Ensure tensor is nn.Parameter type, so program does not fail during edge_program._validate() - filter_tensor = nn.Parameter(filter_tensor.unsqueeze(2)) - set_parameter(filter_tensor, filter_node, self.edge_program) + filter_tensor = get_parameter( + filter_node, self.edge_program + ).unsqueeze(2) + set_parameter( + ( + torch.nn.Parameter(filter_tensor) + if filter_tensor.dtype == torch.float + else filter_tensor + ), + filter_node, + self.edge_program, + ) + num_args = len(node.args) bias_node = node.args[2] - stride = [1] + node.args[3] - padding = [0] + node.args[4] - dilation = [1] + node.args[5] - transpose = node.args[6] - output_padding = [0] + node.args[7] - groups = node.args[8] - - conv2d_node = graph.create_node( - "call_function", - conv_op, - ( - unsqueeze_node, - filter_node, + stride = [1] + node.args[3] if num_args > 3 else [1, 1] + padding = [0] + node.args[4] if num_args > 4 else [0, 0] + if node.target == torch.ops.aten.conv1d.default: + dilation = [1] + node.args[5] if num_args > 5 else [1, 1] + groups = node.args[6] if num_args > 5 else 1 + conv_args = ( + qdq_node_after_unsqueeze, + node.args[1], bias_node, stride, padding, dilation, - transpose, + groups, + ) + else: + output_padding = ( + [0] + node.args[5] if num_args > 5 else [0, 0] + ) + groups = node.args[6] if num_args > 6 else 1 + dilation = [1] + node.args[7] if num_args > 7 else [1, 1] + conv_args = ( + qdq_node_after_unsqueeze, + node.args[1], + bias_node, + stride, + padding, output_padding, groups, - ), + dilation, + ) + conv2d_node = graph.create_node( + "call_function", + self.conv_op_map[node.target], + conv_args, ) conv2d_node.meta = copy_meta( node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)} ) + qdq_node_after_conv2d = self.append_qdq( + graph_module=graph_module, + node=conv2d_node, + qdq_node=list(node.users)[0], + ) - with graph_module.graph.inserting_after(conv2d_node): - squeeze_op = exir_ops.edge.aten.squeeze_copy.dims + with graph_module.graph.inserting_after(qdq_node_after_conv2d): + squeeze_op = torch.ops.aten.squeeze_copy.dims squeeze_node = graph.create_node( "call_function", squeeze_op, ( - conv2d_node, + qdq_node_after_conv2d, [2], ), ) @@ -102,8 +163,10 @@ def call(self, graph_module: torch.fx.GraphModule): QCOM_REQUANTIZE ] conv2d_node.meta.pop(QCOM_REQUANTIZE, None) + for user in node.users.copy(): user.replace_input_with(node, squeeze_node) + graph.eliminate_dead_code() graph_module.recompile() return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py b/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py index 829b3757e06..4fe87604fc1 100644 --- a/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py +++ b/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py @@ -9,6 +9,8 @@ from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.passes import dead_code_elimination_pass +from .utils import dq_ops + class ExpandBroadcastTensorShape(ExportPass): """ @@ -45,9 +47,13 @@ def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule): exir_ops.edge.aten.view_copy.default, (arg, tuple(new_rank)), ) + # try skip dq_ops to get correct param node if applicable + arg_meta = ( + arg.args[0].meta if arg.target in dq_ops else arg.meta + ) # meta needs to be copied elementwisely for fake-tensor # to be updated correctly and not affect meta of arg - for k, v in arg.meta.items(): + for k, v in arg_meta.items(): reshape_node.meta[k] = v reshape_node.meta["val"] = reshape_node.meta["val"].reshape( new_rank diff --git a/backends/qualcomm/_passes/fold_qdq.py b/backends/qualcomm/_passes/fold_qdq.py index bc17b2fae1f..accf66d4c35 100644 --- a/backends/qualcomm/_passes/fold_qdq.py +++ b/backends/qualcomm/_passes/fold_qdq.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch +from executorch.backends.qualcomm.builders.utils import is_parameter +from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.passes import dead_code_elimination_pass @@ -16,23 +18,38 @@ class FoldQDQ(ExportPass): Erase QDQ pattern. """ - def __init__(self): + def __init__(self, edge_program: torch.export.ExportedProgram, force_fold=False): super(FoldQDQ, self).__init__() + self.edge_program = edge_program + self.force_fold = force_fold - def _fold(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + def _annotate_bypass(self, node): + node.meta[QCOM_BYPASS_NODE] = True + for arg in node.args: + if isinstance(arg, torch.fx.Node) and arg.op == "call_function": + self._annotate_bypass(arg) + + def _fold_dq(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: # remove dq for n in graph_module.graph.nodes: user_list = list(n.users.keys()) if n.target not in dq_ops: continue - for user_n in user_list: - user_n.replace_input_with(n, n.args[0]) - graph_module.graph.erase_node(n) + # skip parameters & buffers + if not self.force_fold and is_parameter(n.args[0], self.edge_program): + self._annotate_bypass(n) + else: + for user_n in user_list: + user_n.replace_input_with(n, n.args[0]) + graph_module.graph.erase_node(n) + + def _fold_q(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: # remove q for n in graph_module.graph.nodes: if n.target not in q_ops: continue + to_be_removed = [n] source_n = n.args[0] @@ -57,7 +74,8 @@ def _fold(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: graph_module.graph.erase_node(n) def call(self, graph_module: torch.fx.GraphModule): - self._fold(graph_module) + self._fold_dq(graph_module) + self._fold_q(graph_module) graph_module.recompile() dead_code_elimination_pass(graph_module) return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index c98f27db120..63c303eb689 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -200,6 +200,9 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram): self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(DecomposeExpM1()) + # this pass will rewrite state_dict, it needs to be accomplished before + # to_edge_transform_and_lower + self.add_pass(ConvertConv1dToConv2d(exported_program)) self.add_pass(ConvertSquareToPow()) self.add_pass(LiftConstantScalarOperands()) self._transform(exported_program.graph_module) @@ -207,6 +210,7 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram): return ep def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram): + self.add_pass(FoldQDQ(exported_program, force_fold=True)) self.add_pass(InsertRequantize()) self.add_pass(InsertIOQDQ(exported_program)) self.add_pass(LayoutTransform(exported_program, insert_permute=True)) diff --git a/backends/qualcomm/_passes/replace_index_put_input.py b/backends/qualcomm/_passes/replace_index_put_input.py index dcdf2bb3a7f..93ee21bfc7c 100644 --- a/backends/qualcomm/_passes/replace_index_put_input.py +++ b/backends/qualcomm/_passes/replace_index_put_input.py @@ -33,7 +33,8 @@ def call(self, graph_module: torch.fx.GraphModule): copy_node := list(node.users)[0] ) and copy_node.target == exir_ops.edge.aten.copy.default: m_buffer_node = copy_node.args[0] - bad_frozen_node = node.args[0] + dq_node = node.args[0] + bad_frozen_node = dq_node.args[0] if QCOM_QUANT_ATTRS in bad_frozen_node.meta: m_buffer_node.meta[QCOM_QUANT_ATTRS] = bad_frozen_node.meta[ QCOM_QUANT_ATTRS @@ -43,8 +44,8 @@ def call(self, graph_module: torch.fx.GraphModule): m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] ] ) - with graph.inserting_after(bad_frozen_node): - node.replace_input_with(bad_frozen_node, m_buffer_node) + with graph.inserting_after(dq_node): + node.replace_input_with(dq_node, m_buffer_node) else: continue diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 46d9e0cde76..10dcbb07aac 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -103,7 +103,6 @@ def get_passes_dependency_for_capture_program(): AnnotateStack: [RemoveRedundancy], AnnotateUnbind: [RemoveRedundancy], ConvertBmmToMatmul: [RecomposePixelUnshuffle], - ConvertConv1dToConv2d: [FoldQDQ], ConvertUpsampleBicubicWithBilinear: [RemoveRedundancy], DecomposeAny: [RemoveRedundancy], DecomposeLinalgVectorNorm: [RemoveRedundancy], diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 783a53dd645..22f0852941c 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -227,7 +227,7 @@ Now, we can start to fill in function body step by step: 2. Define input gamma / beta tensors: ```python - weight_node = node.args[2] + weight_node = self.get_node(node.args[2]) weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, @@ -237,7 +237,7 @@ Now, we can start to fill in function body step by step: nodes_to_wrappers, ) - bias_node = node.args[3] + bias_node = self.get_node(node.args[3]) bias_tensor = get_parameter(bias_node, self.edge_program) bias_tensor_wrapper = self.define_tensor( bias_node, diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 7965a30caea..5e9520d4c05 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -11,6 +11,7 @@ import numpy as np import torch +from executorch.backends.qualcomm._passes.utils import dq_ops from executorch.backends.qualcomm.utils.constants import ( QCOM_AXIS, QCOM_AXIS_ORDER, @@ -18,7 +19,6 @@ QCOM_BLOCK_SCALE_BITWIDTH, QCOM_BLOCK_SCALE_OFFSET, QCOM_BLOCK_SCALES, - QCOM_BLOCK_SIZE, QCOM_BLOCK_STORAGE_TYPE, QCOM_DTYPE, QCOM_ENCODING, @@ -95,6 +95,19 @@ def __init__( self.edge_program = edge_program self.enable_tensor_dump = enable_tensor_dump + def get_node(self, node): + """ + Utility to skip dequantize node for frozen param + """ + return node.args[0] if node is not None and node.target in dq_ops else node + + def get_first_user(self, node): + """ + Utility to skip dequantize user for frozen param + """ + user_0 = list(node.users)[0] + return user_0 if user_0.target not in dq_ops else self.get_first_user(user_0) + def get_tensor(self, input_node, op_node, idx=None): """ Get tensor value/shape with axis_order @@ -142,7 +155,9 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict): # symmetric quantization is required scale_offset.append(PyQnnWrapper.Qnn_ScaleOffset_t(max_scale, 0)) - if "convolution" in list(node.users)[0].target.__name__: + # skip dequantize op, e.g. frozen_param -> dq -> conv2d + user_0 = self.get_first_user(node) + if "convolution" in user_0.target.__name__: # OIHW (pytorch) -> HWIO (QNN) quant_config[QCOM_AXIS] = 3 quant_config[QCOM_AXIS_ORDER] = (2, 3, 1, 0) @@ -178,14 +193,11 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict): PyQnnWrapper.Qnn_ScaleOffset_t(scales[i], -zero_points[i]) ) - user_0 = list(node.users)[0] + # skip dequantize op, e.g. frozen_param -> dq -> conv2d + user_0 = self.get_first_user(node) # Memory layout of QNN conv weight always ends in Output. Like conv2d is HWIO - if ( - "convolution" in user_0.target.__name__ - and list(node.users)[0].args[1] == node - ): + if "convolution" in user_0.target.__name__: quant_config[QCOM_AXIS] = 3 - else: quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS] @@ -256,33 +268,21 @@ def get_quant_encoding_conf( def get_quant_tensor_value( self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict ) -> torch.Tensor: - dtype = quant_configs[QCOM_DTYPE] - if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING: + # params should have been quantized by framework + # here we're handling constant operators like arange, full, etc. + if tensor.dtype == torch.float32: + assert quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING, ( + f"unrecongnized quantization attribute detected {quant_attrs[QCOM_ENCODING]}", + ) scale = quant_attrs[QCOM_SCALE] zero_point = quant_attrs[QCOM_ZERO_POINT] - tensor = tensor.div(scale).add(zero_point).round().to(dtype) - elif quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING: - scale = quant_attrs[QCOM_SCALES] - zero_point = quant_attrs[QCOM_ZERO_POINTS] - tensor = tensor.div(scale).add(zero_point).round().to(dtype) - else: # per_block - if axis_order := quant_configs.get(QCOM_AXIS_ORDER, None): - origin_order = tuple( - axis_order.index(x) for x in range(len(axis_order)) - ) - tensor = tensor.permute(origin_order) - tensor = torch.ops.pt2e_quant.quantize_affine( - tensor, - block_size=quant_attrs[QCOM_BLOCK_SIZE], - scale=quant_attrs[QCOM_SCALE], - zero_point=quant_attrs[QCOM_ZERO_POINT], - output_dtype=dtype, - quant_min=quant_attrs[QCOM_QUANT_MIN], - quant_max=quant_attrs[QCOM_QUANT_MAX], + tensor = ( + tensor.div(scale).add(zero_point).round().to(quant_configs[QCOM_DTYPE]) ) - if axis_order: - tensor = tensor.permute(axis_order) - + # Since we're using torch.int32 to store 16bit data + # need to make it compact here for QNN to correctly retrieve data + if quant_configs.get(QCOM_DTYPE) == torch.uint16: + tensor = tensor.to(torch.uint16) # Make the backends access data correctly if quant_configs.get(QCOM_BITWIDTH) == 4: mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8) diff --git a/backends/qualcomm/builders/op_abs.py b/backends/qualcomm/builders/op_abs.py index 002ffe85208..2209ffc792c 100644 --- a/backends/qualcomm/builders/op_abs.py +++ b/backends/qualcomm/builders/op_abs.py @@ -35,7 +35,7 @@ def define_node( ) abs_output_tensors = [output_tensor_wrapper] - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor_wrapper = self.define_tensor( input_node, node, diff --git a/backends/qualcomm/builders/op_adaptive_avg_pool2d.py b/backends/qualcomm/builders/op_adaptive_avg_pool2d.py index c944e1646e7..777e1f61ada 100644 --- a/backends/qualcomm/builders/op_adaptive_avg_pool2d.py +++ b/backends/qualcomm/builders/op_adaptive_avg_pool2d.py @@ -28,7 +28,7 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_add.py b/backends/qualcomm/builders/op_add.py index b5edfd7bb52..f8fb31fb725 100644 --- a/backends/qualcomm/builders/op_add.py +++ b/backends/qualcomm/builders/op_add.py @@ -37,7 +37,7 @@ def define_node( add_input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_amax.py b/backends/qualcomm/builders/op_amax.py index 099004a4bcf..62c17b8dfcd 100644 --- a/backends/qualcomm/builders/op_amax.py +++ b/backends/qualcomm/builders/op_amax.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_and.py b/backends/qualcomm/builders/op_and.py index 44e6f2893f5..22b63e0d6ff 100644 --- a/backends/qualcomm/builders/op_and.py +++ b/backends/qualcomm/builders/op_and.py @@ -37,7 +37,7 @@ def define_node( and_input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_argmin.py b/backends/qualcomm/builders/op_argmin.py index 5630b02a5cc..fa3fad4a61b 100644 --- a/backends/qualcomm/builders/op_argmin.py +++ b/backends/qualcomm/builders/op_argmin.py @@ -27,7 +27,7 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: op_wrapper_list = [] - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) output_tensor = self.get_tensor(node, node) argmin_inp_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_avg_pool2d.py b/backends/qualcomm/builders/op_avg_pool2d.py index 394d4008587..f4762e8bb5a 100644 --- a/backends/qualcomm/builders/op_avg_pool2d.py +++ b/backends/qualcomm/builders/op_avg_pool2d.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_batch_norm.py b/backends/qualcomm/builders/op_batch_norm.py index 9aed1401875..ec0a7c39348 100644 --- a/backends/qualcomm/builders/op_batch_norm.py +++ b/backends/qualcomm/builders/op_batch_norm.py @@ -40,14 +40,22 @@ def update_encoding(self, node: torch.fx.Node, tensor: torch.Tensor, eps): if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): # scale value equals to zero will cause failure in HTP diff = max(abs(tensor.max()), abs(tensor.min())) + eps - quant_attrs[QCOM_SCALE] = diff / quant_attrs[QCOM_QUANT_MAX] + quant_attrs[QCOM_SCALE] = (diff / quant_attrs[QCOM_QUANT_MAX]).item() + + def try_dequantize(self, node: torch.fx.Node, tensor: torch.Tensor): + if tensor.dtype == torch.float: + return tensor + + scale = node.meta[QCOM_QUANT_ATTRS][QCOM_SCALE] + offset = node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] + return tensor.sub(offset).mul(scale).to(torch.float32).contiguous() def define_node( self, node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) eps = 1e-9 @@ -78,9 +86,12 @@ def define_node( batch_norm_output_tensors = [output_tensor_wrapper] n_feature = output_tensor.shape[-1 if QCOM_AXIS_ORDER in node.meta else 1] - filter_node = node.args[1] + filter_node = self.get_node(node.args[1]) if filter_node is not None: - filter_tensor = get_parameter(filter_node, self.edge_program) + # dequantize here for post-process + filter_tensor = self.try_dequantize( + filter_node, get_parameter(filter_node, self.edge_program) + ) else: # 'graph', 'name', 'op', 'target', 'args', and 'kwargs' filter_node = torch.fx.Node( @@ -110,9 +121,12 @@ def define_node( ) batch_norm_input_tensors.append(filter_tensor_wrapper) - bias_node = node.args[2] + bias_node = self.get_node(node.args[2]) if bias_node is not None: - bias_tensor = get_parameter(bias_node, self.edge_program) + # dequantize here for post-process + bias_tensor = self.try_dequantize( + bias_node, get_parameter(bias_node, self.edge_program) + ) amount = (filter_tensor * mean_tensor) / torch.sqrt(var_tensor + eps) bias_tensor = bias_tensor - amount self.update_encoding(bias_node, bias_tensor, eps) diff --git a/backends/qualcomm/builders/op_bmm.py b/backends/qualcomm/builders/op_bmm.py index 46fbff1cc7e..d473d085490 100644 --- a/backends/qualcomm/builders/op_bmm.py +++ b/backends/qualcomm/builders/op_bmm.py @@ -27,7 +27,7 @@ def define_node( ) -> PyQnnWrapper.PyQnnOpWrapper: bmm_input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_cat.py b/backends/qualcomm/builders/op_cat.py index 7f160856390..09f99396589 100644 --- a/backends/qualcomm/builders/op_cat.py +++ b/backends/qualcomm/builders/op_cat.py @@ -32,7 +32,7 @@ def define_node( list_of_tensor_wrappers = [] for tensor_input in list_of_tensors: - input_tensor = self.get_tensor(tensor_input, node) + input_tensor = self.get_tensor(self.get_node(tensor_input), node) list_of_tensor_wrappers.append( self.define_tensor( tensor_input, diff --git a/backends/qualcomm/builders/op_ceil.py b/backends/qualcomm/builders/op_ceil.py index 19fe14d6392..f0a43846d11 100644 --- a/backends/qualcomm/builders/op_ceil.py +++ b/backends/qualcomm/builders/op_ceil.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_clamp.py b/backends/qualcomm/builders/op_clamp.py index 0f9a9ffa196..e80c99db352 100644 --- a/backends/qualcomm/builders/op_clamp.py +++ b/backends/qualcomm/builders/op_clamp.py @@ -27,7 +27,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py index c019a835223..5a168ca103a 100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv2d.py @@ -104,7 +104,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) assert ( input_tensor.dim() == 4 @@ -117,7 +117,7 @@ def define_node( nodes_to_wrappers, ) - filter_node = node.args[1] + filter_node = self.get_node(node.args[1]) filter_tensor = get_parameter(filter_node, self.edge_program) # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO is_transpose_conv = cast(bool, node.args[6]) @@ -133,7 +133,7 @@ def define_node( conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper] if node.args[2] is not None: - bias_node = node.args[2] + bias_node = self.get_node(node.args[2]) bias_tensor = get_parameter(bias_node, self.edge_program) bias_tensor_wrapper = self.define_tensor( bias_node, diff --git a/backends/qualcomm/builders/op_cos.py b/backends/qualcomm/builders/op_cos.py index 589bf3ef88e..69c0d40a026 100644 --- a/backends/qualcomm/builders/op_cos.py +++ b/backends/qualcomm/builders/op_cos.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_cum_sum.py b/backends/qualcomm/builders/op_cum_sum.py index f62485bc519..dceaea83345 100644 --- a/backends/qualcomm/builders/op_cum_sum.py +++ b/backends/qualcomm/builders/op_cum_sum.py @@ -37,7 +37,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_depth_to_space.py b/backends/qualcomm/builders/op_depth_to_space.py index 56c57b4bd5e..357b7a81039 100644 --- a/backends/qualcomm/builders/op_depth_to_space.py +++ b/backends/qualcomm/builders/op_depth_to_space.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_div.py b/backends/qualcomm/builders/op_div.py index ce3f96abc7f..399e914e290 100644 --- a/backends/qualcomm/builders/op_div.py +++ b/backends/qualcomm/builders/op_div.py @@ -37,7 +37,7 @@ def define_node( div_input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_elu.py b/backends/qualcomm/builders/op_elu.py index f9cc089c7bb..f0ac422f4b8 100644 --- a/backends/qualcomm/builders/op_elu.py +++ b/backends/qualcomm/builders/op_elu.py @@ -28,7 +28,7 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: # tensor input - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py index 5b0d1600393..ba5b1a02077 100644 --- a/backends/qualcomm/builders/op_embedding.py +++ b/backends/qualcomm/builders/op_embedding.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - weight_node = node.args[0] + weight_node = self.get_node(node.args[0]) weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, diff --git a/backends/qualcomm/builders/op_eq.py b/backends/qualcomm/builders/op_eq.py index 855c5e13be6..6f33ea78bd1 100644 --- a/backends/qualcomm/builders/op_eq.py +++ b/backends/qualcomm/builders/op_eq.py @@ -37,7 +37,7 @@ def define_node( input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_exp.py b/backends/qualcomm/builders/op_exp.py index 8c4794c9725..f736dec85c2 100644 --- a/backends/qualcomm/builders/op_exp.py +++ b/backends/qualcomm/builders/op_exp.py @@ -26,7 +26,7 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: # tensor input - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_expand.py b/backends/qualcomm/builders/op_expand.py index c098ed00c94..31d248638ab 100644 --- a/backends/qualcomm/builders/op_expand.py +++ b/backends/qualcomm/builders/op_expand.py @@ -27,7 +27,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_ge.py b/backends/qualcomm/builders/op_ge.py index 6784167aa5b..28a29829731 100644 --- a/backends/qualcomm/builders/op_ge.py +++ b/backends/qualcomm/builders/op_ge.py @@ -37,7 +37,7 @@ def define_node( input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_gelu.py b/backends/qualcomm/builders/op_gelu.py index c178740448e..02356a2eef5 100644 --- a/backends/qualcomm/builders/op_gelu.py +++ b/backends/qualcomm/builders/op_gelu.py @@ -26,7 +26,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_group_norm.py b/backends/qualcomm/builders/op_group_norm.py index 26700216b53..a52569cfa7a 100644 --- a/backends/qualcomm/builders/op_group_norm.py +++ b/backends/qualcomm/builders/op_group_norm.py @@ -29,7 +29,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, @@ -39,7 +39,7 @@ def define_node( nodes_to_wrappers, ) - weight_node = node.args[1] + weight_node = self.get_node(node.args[1]) weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, @@ -49,7 +49,7 @@ def define_node( nodes_to_wrappers, ) - bias_node = node.args[2] + bias_node = self.get_node(node.args[2]) bias_tensor = get_parameter(bias_node, self.edge_program) bias_tensor_wrapper = self.define_tensor( bias_node, diff --git a/backends/qualcomm/builders/op_gt.py b/backends/qualcomm/builders/op_gt.py index 6c311f42b7f..8c1ef3a600c 100644 --- a/backends/qualcomm/builders/op_gt.py +++ b/backends/qualcomm/builders/op_gt.py @@ -37,7 +37,7 @@ def define_node( input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_hardsigmoid.py b/backends/qualcomm/builders/op_hardsigmoid.py index 1acc08a387d..c30cae92f55 100644 --- a/backends/qualcomm/builders/op_hardsigmoid.py +++ b/backends/qualcomm/builders/op_hardsigmoid.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_hardswish.py b/backends/qualcomm/builders/op_hardswish.py index ed28ff95f78..fb4d0a40515 100644 --- a/backends/qualcomm/builders/op_hardswish.py +++ b/backends/qualcomm/builders/op_hardswish.py @@ -26,7 +26,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_hardtanh.py b/backends/qualcomm/builders/op_hardtanh.py index 68bafaaab8b..4025a060ff3 100644 --- a/backends/qualcomm/builders/op_hardtanh.py +++ b/backends/qualcomm/builders/op_hardtanh.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_index.py b/backends/qualcomm/builders/op_index.py index ff039f9d7a8..fe6bf4262d8 100644 --- a/backends/qualcomm/builders/op_index.py +++ b/backends/qualcomm/builders/op_index.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py index c317cc0a8b7..9c11a6ca891 100644 --- a/backends/qualcomm/builders/op_index_put.py +++ b/backends/qualcomm/builders/op_index_put.py @@ -20,7 +20,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, @@ -50,7 +50,7 @@ def define_node( PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) - value_node = node.args[2] + value_node = self.get_node(node.args[2]) value_tensor = self.get_tensor(value_node, node) diff --git a/backends/qualcomm/builders/op_instance_norm.py b/backends/qualcomm/builders/op_instance_norm.py index e7e7f14a944..828e89a97f2 100644 --- a/backends/qualcomm/builders/op_instance_norm.py +++ b/backends/qualcomm/builders/op_instance_norm.py @@ -35,7 +35,9 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node, weight_node, bias_node = node.args[0:3] + input_node = self.get_node(node.args[0]) + weight_node = self.get_node(node.args[1]) + bias_node = self.get_node(node.args[2]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index 06f822014ed..5316cb1dabe 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -30,7 +30,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, @@ -53,7 +53,7 @@ def define_node( axis = [len(input_tensor.shape) - 1] axis_shape = [len(axis)] - weight_node = node.args[2] + weight_node = self.get_node(node.args[2]) weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, @@ -65,7 +65,7 @@ def define_node( layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper] - bias_node = node.args[3] + bias_node = self.get_node(node.args[3]) if bias_node is not None: bias_tensor = get_parameter(bias_node, self.edge_program) bias_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_le.py b/backends/qualcomm/builders/op_le.py index 1dd2a06b777..e5784049c5c 100644 --- a/backends/qualcomm/builders/op_le.py +++ b/backends/qualcomm/builders/op_le.py @@ -37,7 +37,7 @@ def define_node( input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py index 71b6072b9e5..71716e81bca 100644 --- a/backends/qualcomm/builders/op_linear.py +++ b/backends/qualcomm/builders/op_linear.py @@ -34,7 +34,7 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: linear_input_tensors = [] - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, @@ -45,7 +45,7 @@ def define_node( ) linear_input_tensors.append(input_tensor_wrapper) - weight_node = node.args[1] + weight_node = self.get_node(node.args[1]) if ( quant_attrs := weight_node.meta.get(QCOM_QUANT_ATTRS) ) and QCOM_SCALES in quant_attrs: @@ -67,7 +67,7 @@ def define_node( linear_input_tensors.append(weight_tensor_wrapper) if len(node.args) >= 3: - bias_node = node.args[2] + bias_node = self.get_node(node.args[2]) # TODO remove this when qnn sdk support if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}): diff --git a/backends/qualcomm/builders/op_log.py b/backends/qualcomm/builders/op_log.py index bcc40aa6268..65125e42316 100644 --- a/backends/qualcomm/builders/op_log.py +++ b/backends/qualcomm/builders/op_log.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) log_inp_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_log_softmax.py b/backends/qualcomm/builders/op_log_softmax.py index d395d5eb66e..2d6c857591e 100644 --- a/backends/qualcomm/builders/op_log_softmax.py +++ b/backends/qualcomm/builders/op_log_softmax.py @@ -27,7 +27,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) log_softmax_inp_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_logical_not.py b/backends/qualcomm/builders/op_logical_not.py index 457a1007ada..1eed7d894de 100644 --- a/backends/qualcomm/builders/op_logical_not.py +++ b/backends/qualcomm/builders/op_logical_not.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_lt.py b/backends/qualcomm/builders/op_lt.py index b4a080efc38..9494aac9d29 100644 --- a/backends/qualcomm/builders/op_lt.py +++ b/backends/qualcomm/builders/op_lt.py @@ -37,7 +37,7 @@ def define_node( input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_matmul.py b/backends/qualcomm/builders/op_matmul.py index 577bcb12a42..8d45424bd62 100644 --- a/backends/qualcomm/builders/op_matmul.py +++ b/backends/qualcomm/builders/op_matmul.py @@ -27,7 +27,7 @@ def define_node( ) -> PyQnnWrapper.PyQnnOpWrapper: matmul_input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_max.py b/backends/qualcomm/builders/op_max.py index 7d41358a266..57e119922ed 100644 --- a/backends/qualcomm/builders/op_max.py +++ b/backends/qualcomm/builders/op_max.py @@ -37,7 +37,7 @@ def define_node( min_input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_max_pool2d.py b/backends/qualcomm/builders/op_max_pool2d.py index 8d0087eb2c6..a0ef685acd0 100644 --- a/backends/qualcomm/builders/op_max_pool2d.py +++ b/backends/qualcomm/builders/op_max_pool2d.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_mean_dim.py b/backends/qualcomm/builders/op_mean_dim.py index 313b24420db..8fb0e9e3c95 100644 --- a/backends/qualcomm/builders/op_mean_dim.py +++ b/backends/qualcomm/builders/op_mean_dim.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_min.py b/backends/qualcomm/builders/op_min.py index 0df2796974d..72224500b0e 100644 --- a/backends/qualcomm/builders/op_min.py +++ b/backends/qualcomm/builders/op_min.py @@ -37,7 +37,7 @@ def define_node( min_input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_mul.py b/backends/qualcomm/builders/op_mul.py index 3138d3b8c9b..36e0c91cf7a 100644 --- a/backends/qualcomm/builders/op_mul.py +++ b/backends/qualcomm/builders/op_mul.py @@ -37,7 +37,7 @@ def define_node( mul_input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_ne.py b/backends/qualcomm/builders/op_ne.py index 0227b02efbf..e9b723a88c5 100644 --- a/backends/qualcomm/builders/op_ne.py +++ b/backends/qualcomm/builders/op_ne.py @@ -8,14 +8,6 @@ import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper import torch -from executorch.backends.qualcomm.utils.constants import ( - QCOM_QUANT_ATTRS, - QCOM_QUANT_MAX, - QCOM_QUANT_MIN, - QCOM_SCALE, - QCOM_ZERO_POINT, -) -from executorch.exir.dialects._ops import ops as exir_ops from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpElementWiseNotEqual, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -45,38 +37,9 @@ def define_node( input_tensors = [] for index in range(2): - input_node = node.args[index] - if isinstance(input_node, torch.fx.Node): - input_tensor = self.get_tensor(input_node, node) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE - else: - scalar = input_node - input_tensor = torch.tensor(scalar, dtype=torch.float32) - tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC - - # 'graph', 'name', 'op', 'target', 'args', and 'kwargs' - input_node = torch.fx.Node( - node.graph, - node.name + "_runtime_scalar", - "call_function", - exir_ops.edge.aten.scalar_tensor.default, - (), # args - {}, # kwargs - ) - # Because the output data type of the ne node is boolean. - # We need to take the quant attr from the non-scalar node. - if quant_attrs := node.args[index ^ 1].meta.get(QCOM_QUANT_ATTRS): - quant_attrs = quant_attrs.copy() - quant_range = ( - quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN] - ) - quant_attrs[QCOM_ZERO_POINT] = ( - 0 if scalar >= 0 else quant_attrs[QCOM_QUANT_MAX] - ) - quant_attrs[QCOM_SCALE] = ( - scalar / quant_range if scalar >= 0 else -scalar / quant_range - ) - input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + input_node = self.get_node(node.args[index]) + input_tensor = self.get_tensor(input_node, node) + tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_neg.py b/backends/qualcomm/builders/op_neg.py index a950a1887ab..fd48cbe2791 100644 --- a/backends/qualcomm/builders/op_neg.py +++ b/backends/qualcomm/builders/op_neg.py @@ -24,7 +24,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) neg_inp_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_or.py b/backends/qualcomm/builders/op_or.py index c2751744788..483831db0f7 100644 --- a/backends/qualcomm/builders/op_or.py +++ b/backends/qualcomm/builders/op_or.py @@ -37,7 +37,7 @@ def define_node( or_input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_pad.py b/backends/qualcomm/builders/op_pad.py index 5ec34065f8b..7b210ed6838 100644 --- a/backends/qualcomm/builders/op_pad.py +++ b/backends/qualcomm/builders/op_pad.py @@ -27,7 +27,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) pad_inp_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_pow.py b/backends/qualcomm/builders/op_pow.py index 3e89bdcfc4d..996d3b353e2 100644 --- a/backends/qualcomm/builders/op_pow.py +++ b/backends/qualcomm/builders/op_pow.py @@ -37,7 +37,7 @@ def define_node( pow_output_tensors = [output_tensor_wrapper] # tensor input - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE @@ -51,7 +51,7 @@ def define_node( ) # exp input - exp_node = node.args[1] + exp_node = self.get_node(node.args[1]) exp_tensor = self.get_tensor(exp_node, node) exp_tensor_wrapper = self.define_tensor( exp_node, diff --git a/backends/qualcomm/builders/op_prelu.py b/backends/qualcomm/builders/op_prelu.py index e35839f535e..69ea5e005a7 100644 --- a/backends/qualcomm/builders/op_prelu.py +++ b/backends/qualcomm/builders/op_prelu.py @@ -26,7 +26,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) prelu_inp_tensor_wrapper = self.define_tensor( input_node, @@ -36,23 +36,18 @@ def define_node( nodes_to_wrappers, ) - coeff_node = node.args[1] - coeff_tensor = torch.zeros(input_node.meta["val"].shape) + coeff_node = self.get_node(node.args[1]) coeff = get_parameter(coeff_node, self.edge_program) - # param nodes will be FakeTensor when doing partition - # fill in random numeric for validation - if isinstance(coeff, torch._subclasses.fake_tensor.FakeTensor): - coeff = torch.ones(coeff.shape) + coeff_tensor = torch.zeros(input_node.meta["val"].shape, dtype=coeff.dtype) # per-channel activation if coeff_node.meta["val"].shape[0] > 1: for i in range(input_node.meta["val"].shape[1]): coeff_tensor = coeff_tensor.index_fill(1, torch.tensor([i]), coeff[i]) - if QCOM_AXIS_ORDER in input_node.meta: - axis_order = input_node.meta[QCOM_AXIS_ORDER] - coeff_tensor = coeff_tensor.permute(dims=axis_order).contiguous() else: - coeff = coeff.item() - coeff_tensor = torch.full(input_tensor.shape, coeff).to(torch.float32) + coeff_tensor.fill_(coeff[0]) + + if axis_order := input_node.meta.get(QCOM_AXIS_ORDER, None): + coeff_tensor = coeff_tensor.permute(dims=axis_order).contiguous() coeff_tensor_wrapper = self.define_tensor( coeff_node, diff --git a/backends/qualcomm/builders/op_relu.py b/backends/qualcomm/builders/op_relu.py index 29335797e28..d237b84efe1 100644 --- a/backends/qualcomm/builders/op_relu.py +++ b/backends/qualcomm/builders/op_relu.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) relu_inp_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_repeat.py b/backends/qualcomm/builders/op_repeat.py index 9748f1e9619..e5867e64447 100644 --- a/backends/qualcomm/builders/op_repeat.py +++ b/backends/qualcomm/builders/op_repeat.py @@ -27,7 +27,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_reshape.py b/backends/qualcomm/builders/op_reshape.py index ff4a603fa5b..6e25c65e16d 100644 --- a/backends/qualcomm/builders/op_reshape.py +++ b/backends/qualcomm/builders/op_reshape.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py index aa7f9becd98..fdf49b09fef 100644 --- a/backends/qualcomm/builders/op_rms_norm.py +++ b/backends/qualcomm/builders/op_rms_norm.py @@ -36,7 +36,7 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: # args of node : ['input', 'normalized_shape', 'weight', 'eps'] - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, @@ -60,7 +60,7 @@ def define_node( axes = [node.args[0].meta["val"].dim() - 1] axes_shape = [len(axes)] - weight_node = node.args[2] + weight_node = self.get_node(node.args[2]) weight_tensor = get_parameter(weight_node, self.edge_program) weight_tensor_wrapper = self.define_tensor( weight_node, @@ -71,7 +71,7 @@ def define_node( ) # Fake node, nn module seems to be inconsistent with document - bias_tensor = torch.zeros(weight_tensor.shape) + bias_tensor = torch.zeros(weight_tensor.shape, dtype=weight_tensor.dtype) bias_node = torch.fx.Node( node.graph, node.name + "_runtime_bias", diff --git a/backends/qualcomm/builders/op_rsqrt.py b/backends/qualcomm/builders/op_rsqrt.py index 162b485e9e5..b1995e28dde 100644 --- a/backends/qualcomm/builders/op_rsqrt.py +++ b/backends/qualcomm/builders/op_rsqrt.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) rsqrt_inp_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_scalar_tensor.py b/backends/qualcomm/builders/op_scalar_tensor.py index d236f6674df..2e9154115bc 100644 --- a/backends/qualcomm/builders/op_scalar_tensor.py +++ b/backends/qualcomm/builders/op_scalar_tensor.py @@ -13,7 +13,7 @@ @register_node_visitor -class Arange(NodeVisitor): +class ScalarTensor(NodeVisitor): target = ["scalar_tensor.default"] def __init__(self, *args) -> None: diff --git a/backends/qualcomm/builders/op_select_copy.py b/backends/qualcomm/builders/op_select_copy.py index 148888f1497..c5a7c0f7c99 100644 --- a/backends/qualcomm/builders/op_select_copy.py +++ b/backends/qualcomm/builders/op_select_copy.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_sigmoid.py b/backends/qualcomm/builders/op_sigmoid.py index ae6e6709c0a..ce820c8f4ee 100644 --- a/backends/qualcomm/builders/op_sigmoid.py +++ b/backends/qualcomm/builders/op_sigmoid.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) sigmoid_inp_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_sin.py b/backends/qualcomm/builders/op_sin.py index 8828685ac9e..f9a0b1c2e63 100644 --- a/backends/qualcomm/builders/op_sin.py +++ b/backends/qualcomm/builders/op_sin.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py index 8d12e03c0bb..7d3a154e9f1 100644 --- a/backends/qualcomm/builders/op_slice_copy.py +++ b/backends/qualcomm/builders/op_slice_copy.py @@ -26,7 +26,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE @@ -50,12 +50,17 @@ def define_node( dim = cast(int, node.args[1]) if dim < 0: dim = dim % len(input_tensor.shape) - start = cast(int, node.args[2]) + + start = 0 if node.args[2] is None else cast(int, node.args[2]) if start < 0: start = start % input_tensor.shape[dim] - end = min(cast(int, node.args[3]), input_tensor.shape[dim]) - if end < 0: - end = end % input_tensor.shape[dim] + + if len(node.args) > 3: + end = min(cast(int, node.args[3]), input_tensor.shape[dim]) + if end < 0: + end = end % input_tensor.shape[dim] + else: + end = input_tensor.shape[dim] input_tensor_rank = len(input_tensor.shape) ranges = [] diff --git a/backends/qualcomm/builders/op_softmax.py b/backends/qualcomm/builders/op_softmax.py index f6f826e2a40..43cc6438b9b 100644 --- a/backends/qualcomm/builders/op_softmax.py +++ b/backends/qualcomm/builders/op_softmax.py @@ -27,7 +27,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) softmax_inp_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_space_to_depth.py b/backends/qualcomm/builders/op_space_to_depth.py index 0282cf3f15a..84c79d841d8 100644 --- a/backends/qualcomm/builders/op_space_to_depth.py +++ b/backends/qualcomm/builders/op_space_to_depth.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_split_with_sizes.py b/backends/qualcomm/builders/op_split_with_sizes.py index 138f6ed60ec..b70d74aa339 100644 --- a/backends/qualcomm/builders/op_split_with_sizes.py +++ b/backends/qualcomm/builders/op_split_with_sizes.py @@ -28,7 +28,7 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_sqrt.py b/backends/qualcomm/builders/op_sqrt.py index 5505e92ee67..ff5a0c086e0 100644 --- a/backends/qualcomm/builders/op_sqrt.py +++ b/backends/qualcomm/builders/op_sqrt.py @@ -26,7 +26,7 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: # tensor input - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_squeeze.py b/backends/qualcomm/builders/op_squeeze.py index b828bb7b0b9..94d6e5a3cf9 100644 --- a/backends/qualcomm/builders/op_squeeze.py +++ b/backends/qualcomm/builders/op_squeeze.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_stack.py b/backends/qualcomm/builders/op_stack.py index fdef148ad4d..25b7d353dc4 100644 --- a/backends/qualcomm/builders/op_stack.py +++ b/backends/qualcomm/builders/op_stack.py @@ -30,7 +30,7 @@ def define_node( input_node_list = node.args[0] stack_input_tensors = [] for input_node in input_node_list: - input_tensor = self.get_tensor(input_node, node) + input_tensor = self.get_tensor(self.get_node(input_node), node) stack_inp_tensor_wrapper = self.define_tensor( input_node, node, diff --git a/backends/qualcomm/builders/op_sub.py b/backends/qualcomm/builders/op_sub.py index 954ca9d3917..e7e5b22bb96 100644 --- a/backends/qualcomm/builders/op_sub.py +++ b/backends/qualcomm/builders/op_sub.py @@ -37,7 +37,7 @@ def define_node( sub_input_tensors = [] for index in range(2): - input_node = node.args[index] + input_node = self.get_node(node.args[index]) input_tensor = self.get_tensor(input_node, node) tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE diff --git a/backends/qualcomm/builders/op_sum_int_list.py b/backends/qualcomm/builders/op_sum_int_list.py index 74181f46cb3..fc5546f9d33 100644 --- a/backends/qualcomm/builders/op_sum_int_list.py +++ b/backends/qualcomm/builders/op_sum_int_list.py @@ -28,7 +28,7 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_tanh.py b/backends/qualcomm/builders/op_tanh.py index ddc9fd2a2a6..c06f44b312f 100644 --- a/backends/qualcomm/builders/op_tanh.py +++ b/backends/qualcomm/builders/op_tanh.py @@ -26,7 +26,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py index 5fb016aef95..dc1062846ed 100644 --- a/backends/qualcomm/builders/op_to.py +++ b/backends/qualcomm/builders/op_to.py @@ -80,7 +80,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_topk.py b/backends/qualcomm/builders/op_topk.py index 745cf7b9935..2b5d23268b9 100644 --- a/backends/qualcomm/builders/op_topk.py +++ b/backends/qualcomm/builders/op_topk.py @@ -33,7 +33,7 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_transpose.py b/backends/qualcomm/builders/op_transpose.py index d29fc73084c..7fb02a2fb7c 100644 --- a/backends/qualcomm/builders/op_transpose.py +++ b/backends/qualcomm/builders/op_transpose.py @@ -28,7 +28,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) permute_node = input_node if QCOM_INSERTED_PERMUTE in node.meta else node input_tensor = self.get_tensor(input_node, permute_node) input_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_unbind.py b/backends/qualcomm/builders/op_unbind.py index 8ca62e2a07b..1c505e6f4fd 100644 --- a/backends/qualcomm/builders/op_unbind.py +++ b/backends/qualcomm/builders/op_unbind.py @@ -27,7 +27,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_unsqueeze.py b/backends/qualcomm/builders/op_unsqueeze.py index 55790129462..f5cd7af3b2e 100644 --- a/backends/qualcomm/builders/op_unsqueeze.py +++ b/backends/qualcomm/builders/op_unsqueeze.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_upsample_bilinear2d.py b/backends/qualcomm/builders/op_upsample_bilinear2d.py index 654fb934571..10dfe375fe0 100644 --- a/backends/qualcomm/builders/op_upsample_bilinear2d.py +++ b/backends/qualcomm/builders/op_upsample_bilinear2d.py @@ -26,7 +26,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_upsample_nearest2d.py b/backends/qualcomm/builders/op_upsample_nearest2d.py index c4b353fd3e9..4e9c4741ca2 100644 --- a/backends/qualcomm/builders/op_upsample_nearest2d.py +++ b/backends/qualcomm/builders/op_upsample_nearest2d.py @@ -26,7 +26,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] + input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/op_where.py b/backends/qualcomm/builders/op_where.py index ecac45a7a6f..94ee1b0e940 100644 --- a/backends/qualcomm/builders/op_where.py +++ b/backends/qualcomm/builders/op_where.py @@ -25,7 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - conditional_input_node = node.args[0] + conditional_input_node = self.get_node(node.args[0]) conditional_input_tensor = self.get_tensor(conditional_input_node, node) conditional_input_tensor_wrapper = self.define_tensor( conditional_input_node, @@ -35,7 +35,7 @@ def define_node( nodes_to_wrappers, ) - true_input_node = node.args[1] + true_input_node = self.get_node(node.args[1]) true_input_tensor = self.get_tensor(true_input_node, node) true_input_tensor_wrapper = self.define_tensor( true_input_node, @@ -45,7 +45,7 @@ def define_node( nodes_to_wrappers, ) - false_input_node = node.args[2] + false_input_node = self.get_node(node.args[2]) false_input_tensor = self.get_tensor(false_input_node, node) false_input_tensor_wrapper = self.define_tensor( false_input_node, diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index d9eb188614c..7e5a779e748 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -12,7 +12,10 @@ from executorch.backends.qualcomm.builders import node_visitor from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader from executorch.backends.qualcomm.qnn_preprocess import QnnBackend -from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER +from executorch.backends.qualcomm.utils.constants import ( + QCOM_AXIS_ORDER, + QCOM_BYPASS_NODE, +) from executorch.exir.backend.backend_details import CompileSpec from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( @@ -46,7 +49,6 @@ def __init__( skip_node_op_set: set = None, ): self.node_visitors = node_visitor.get_node_visitors(edge_program) - self.skip_node_op_set = skip_node_op_set self.skip_node_id_set = skip_node_id_set self.nodes_to_wrappers = defaultdict(dict) @@ -70,6 +72,8 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: node.target in allow_list_operator # bypass if custom op appears or OpContextLoader.namespace == node.target.namespace + # bypass dequantize op for parameters & buffers + or node.meta.get(QCOM_BYPASS_NODE, False) ): return True diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 338209fcd4a..74c85b773c2 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1535,8 +1535,11 @@ def test_qnn_backend_elu(self): def test_qnn_backend_embedding(self): module = Embedding() # noqa: F405 sample_input = (torch.Tensor([[1, 2, 4, 5], [4, 3, 2, 9]]).to(torch.int32),) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + quant_dtype = [QuantDtype.use_8a8w, QuantDtype.use_16a4w] + for i, qdtype in enumerate(quant_dtype): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input, quant_dtype=qdtype) + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_equal(self): test_comb = [ @@ -2505,7 +2508,7 @@ def test_qnn_backend_profile_op(self): module, sample_input, expected_partitions=1, - expected_profile_events=34, + expected_profile_events=30, ) def test_qnn_backend_shared_buffer(self): @@ -2527,6 +2530,9 @@ def test_qnn_backend_shared_buffer(self): ) def test_qnn_backend_online_prepare(self): + if self.enable_x86_64: + self.skipTest("online prepare is not supported on host machine") + backend_options = generate_htp_compiler_spec(use_fp16=True) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( soc_model=self.chipset_table[TestQNN.model], @@ -3150,7 +3156,7 @@ def test_qnn_backend_profile_op(self): module, sample_input, expected_partitions=1, - expected_profile_events=35, + expected_profile_events=30, ) def test_qnn_backend_shared_buffer(self): @@ -3173,6 +3179,9 @@ def test_qnn_backend_shared_buffer(self): ) def test_qnn_backend_online_prepare(self): + if self.enable_x86_64: + self.skipTest("online prepare is not supported on host machine") + backend_options = generate_htp_compiler_spec(use_fp16=False) TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( soc_model=self.chipset_table[TestQNN.model], @@ -3300,17 +3309,17 @@ def test_qnn_backend_dump_context_from_pte(self): def test_qnn_backend_draw_graph(self): golden_data = """digraph test { rankdir=TB - aten_convolution_default_0 [label=< + aten_convolution_default_1_0 [label=< - +
name: aten_convolution_default_0
name: aten_convolution_default_1_0
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] - aten_relu_default_0 [label=< + aten_relu_default_1_0 [label=< - + @@ -3340,17 +3349,33 @@ def test_qnn_backend_draw_graph(self):
name: aten_relu_default_0
name: aten_relu_default_1_0
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
dims: [32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] - aten_convolution_default_1_0 [label=< + b__frozen_param0_0 [label=< - + + + + + +
name: aten_convolution_default_1_0
name: b__frozen_param0_0
data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [3, 3, 32, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + b__frozen_param1_0 [label=< + + + + + + +
name: b__frozen_param1_0
data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + aten_convolution_default_0 [label=< + +
name: aten_convolution_default_0
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] - aten_relu_default_1_0 [label=< + aten_relu_default_0 [label=< - + @@ -3364,14 +3389,6 @@ def test_qnn_backend_draw_graph(self):
name: aten_relu_default_1_0
name: aten_relu_default_0
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] - output_quantized_decomposed_dequantize_per_tensor_tensor_0 [label=< - - - - - - -
name: output_quantized_decomposed_dequantize_per_tensor_tensor_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ
dims: [1, 32, 28, 28]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] input_0_x_0 [label=< @@ -3380,36 +3397,27 @@ def test_qnn_backend_draw_graph(self):
name: input_0_x_0
dims: [1, 32, 28, 28]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] - b__frozen_param0_0 [label=< - - - - - - -
name: b__frozen_param0_0
data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [3, 3, 32, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] - b__frozen_param1_0 [label=< + output_quantized_decomposed_dequantize_per_tensor_tensor_0 [label=< - - - - - + + + + +
name: b__frozen_param1_0
data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET
name: output_quantized_decomposed_dequantize_per_tensor_tensor_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ
dims: [1, 32, 28, 28]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] - quantized_decomposed_quantize_per_tensor_default_0 -> aten_convolution_default_0 - input_0_x_0 -> quantized_decomposed_quantize_per_tensor_default_0 - b__frozen_param0_0 -> aten_convolution_default_0 - b__frozen_param1_0 -> aten_convolution_default_0 - aten_convolution_default_0 -> aten_relu_default_0 quantized_decomposed_quantize_per_tensor_default_0 -> aten_convolution_default_1_0 + input_0_x_0 -> quantized_decomposed_quantize_per_tensor_default_0 b__frozen_param2_0 -> aten_convolution_default_1_0 b__frozen_param3_0 -> aten_convolution_default_1_0 aten_convolution_default_1_0 -> aten_relu_default_1_0 + quantized_decomposed_quantize_per_tensor_default_0 -> aten_convolution_default_0 + b__frozen_param0_0 -> aten_convolution_default_0 + b__frozen_param1_0 -> aten_convolution_default_0 + aten_convolution_default_0 -> aten_relu_default_0 aten_relu_default_0 -> aten_add_tensor_0 aten_relu_default_1_0 -> aten_add_tensor_0 aten_add_tensor_0 -> output_quantized_decomposed_dequantize_per_tensor_tensor_0 - } - """ + }""" module = DrawGraphModel() # noqa: F405 sample_input = (torch.randn(1, 32, 28, 28),) module = self.get_qdq_module(module, sample_input) diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 695c846de05..0b34290f4c2 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -15,6 +15,7 @@ import numpy as np import torch from executorch import exir +from executorch.backends.qualcomm._passes.utils import dq_ops from executorch.backends.qualcomm.qnn_preprocess import QnnBackend from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset @@ -298,7 +299,7 @@ def validate_profile(): target_time_scale=TimeScale.CYCLES, ) self.assertTrue( - len(inspector.to_dataframe().index) == expected_profile_events + len(inspector.to_dataframe().index) >= expected_profile_events ) def validate_intermediate_tensor(): @@ -583,7 +584,7 @@ def get_converted_sgd_trained_module( optimizer.zero_grad() loss.backward() optimizer.step() - return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared) + return convert_pt2e(prepared) def split_graph(self, division: int): class SplitGraph(ExportPass): @@ -595,6 +596,10 @@ def __init__(self, division): super().__init__() self.division = division + def _is_legit_node(self, node): + # skip dq_ops for frozen_params + return node.op == "call_function" and node.target not in dq_ops + def _insert_clone( self, graph_module: torch.fx.GraphModule ) -> torch.fx.GraphModule: @@ -609,9 +614,11 @@ def _insert_clone( # Insert clone op to split model based on the shares num_graph_nodes = 0 for node in graph_module.graph.nodes: - num_graph_nodes += 1 if node.op == "call_function" else 0 + if not self._is_legit_node(node): + continue - if num_graph_nodes % shares != 0 or node.op != "call_function": + num_graph_nodes += 1 + if num_graph_nodes % shares != 0: continue with graph_module.graph.inserting_after(node): diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py index ce917bf4115..a4a087287a4 100644 --- a/backends/qualcomm/utils/constants.py +++ b/backends/qualcomm/utils/constants.py @@ -15,6 +15,7 @@ QCOM_BLOCK_SCALE_BITWIDTH = "block_scale_bitwidth" QCOM_BLOCK_SCALE_OFFSET = "block_scale_offset" QCOM_BLOCK_STORAGE_TYPE = "block_storage_type" +QCOM_BYPASS_NODE = "bypass_node" QCOM_DATA = "data" QCOM_DTYPE = "dtype" QCOM_ENCODING = "encoding" diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 3653cd3176f..e80b03a32c0 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -640,7 +640,8 @@ def prepare_subgm(subgm, subgm_name): for node in graph_module.graph.nodes: if node.op == "call_module": graph_module.set_submodule( - node.name, convert_pt2e(graph_module.get_submodule(node.name)) + node.name, + convert_pt2e(graph_module.get_submodule(node.name)), ) # canonicalize graph for lowering again graph_module, edge_prog_mgrs = _canonicalize_graph_with_lowered_module( diff --git a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt index 4d4f1c2e39d..b9b1ddc0f72 100644 --- a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt @@ -52,6 +52,7 @@ target_link_libraries( qnn_executorch_backend executorch_core extension_data_loader + extension_flat_tensor extension_module extension_tensor gflags diff --git a/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt b/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt index 16d91013349..4e44a1599b1 100644 --- a/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt +++ b/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt @@ -39,6 +39,7 @@ target_link_libraries( qnn_executorch_backend executorch_core extension_data_loader + extension_flat_tensor extension_module extension_tensor gflags @@ -87,6 +88,7 @@ target_link_libraries( qnn_executorch_backend executorch_core extension_data_loader + extension_flat_tensor extension_module extension_tensor gflags diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt b/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt index ff22f08cd09..5b63a6678fc 100644 --- a/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt +++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt @@ -23,6 +23,7 @@ target_link_libraries( qnn_executorch_backend executorch_core extension_data_loader + extension_flat_tensor extension_module extension_tensor gflags diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index d8dab88e998..b5ac77c3d1f 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -251,8 +251,8 @@ def qat_train(ori_model, captured_model, quantizer, dataset): loss.backward() optimizer.step() - return torch.ao.quantization.quantize_pt2e.convert_pt2e( - torch.ao.quantization.move_exported_model_to_eval(annotated_model) + return convert_pt2e( + torch.ao.quantization.move_exported_model_to_eval(annotated_model), ) diff --git a/exir/program/_program.py b/exir/program/_program.py index f24807e253d..155591dd1af 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1283,7 +1283,7 @@ def to_edge_transform_and_lower( @experimental( """ - This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed. + This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed. This function will be combined with to_edge in the future. """ )