diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py
index b4f14fc28cd..d9ef9cb691d 100644
--- a/backends/qualcomm/_passes/annotate_quant_attrs.py
+++ b/backends/qualcomm/_passes/annotate_quant_attrs.py
@@ -7,10 +7,8 @@
from typing import Any, Dict
import torch
-from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
+from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.utils.constants import (
- QCOM_AXIS,
- QCOM_BLOCK_SIZE,
QCOM_DTYPE,
QCOM_ENCODING,
QCOM_QUANT_ATTRS,
@@ -18,11 +16,8 @@
QCOM_QUANT_MIN,
QCOM_REQUANTIZE,
QCOM_SCALE,
- QCOM_SCALES,
QCOM_ZERO_POINT,
- QCOM_ZERO_POINTS,
)
-from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from .utils import dq_ops, get_quant_attrs, q_ops
@@ -101,43 +96,9 @@ def _annotate_requant(self, n):
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
- # Dequant all the fold_quant parameters back to fp32.
- # If an operation is not supported by QNN and got fallback, it will expect a fp32 param.
- def _dequant_fold_params(self, n, quant_attrs, param):
- if quant_attrs[QCOM_ENCODING] in [
- exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
- ]:
- dim, axis = param.dim(), quant_attrs[QCOM_AXIS]
- scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis)
- offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis)
- param = param.sub(offsets).mul(scales).to(torch.float32).contiguous()
- elif quant_attrs[QCOM_ENCODING] in [
- exir_ops.edge.pt2e_quant.dequantize_affine.default
- ]:
- param = torch.ops.pt2e_quant.dequantize_affine(
- param,
- block_size=quant_attrs[QCOM_BLOCK_SIZE],
- scale=quant_attrs[QCOM_SCALE],
- zero_point=quant_attrs[QCOM_ZERO_POINT],
- input_dtype=quant_attrs[QCOM_DTYPE],
- quant_min=quant_attrs[QCOM_QUANT_MIN],
- quant_max=quant_attrs[QCOM_QUANT_MAX],
- output_dtype=torch.float32,
- )
- else:
- scale = quant_attrs[QCOM_SCALE]
- offset = quant_attrs[QCOM_ZERO_POINT]
- param = param.sub(offset).mul(scale).to(torch.float32).contiguous()
-
- set_parameter(param, n.args[0], self.edge_program)
- n.args[0].meta["val"] = param
-
def _annotate_quant_attrs(
self, graph_module: torch.fx.GraphModule
) -> torch.fx.GraphModule:
- # Keep track of const params that has been dequant, so it does not get
- # dequant multiple times if the const param has more than 1 user
- visited_const_param = set()
for n in graph_module.graph.nodes:
self._annotate_requant(n)
# With fold_quant enabled, check if the input of dq op is quantized param.
@@ -149,10 +110,6 @@ def _annotate_quant_attrs(
quant_attrs = get_quant_attrs(self.edge_program, n)
self._annotate_source_nodes(n, quant_attrs)
- if param is not None and n.args[0] not in visited_const_param:
- visited_const_param.add(n.args[0])
- self._dequant_fold_params(n, quant_attrs, param)
-
return graph_module
def call(self, graph_module: torch.fx.GraphModule):
diff --git a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
index 72dc29c2880..dabe0243a47 100644
--- a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
+++ b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
@@ -5,10 +5,8 @@
# LICENSE file in the root directory of this source tree.
import torch
-import torch.nn as nn
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
-from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from .utils import copy_meta
@@ -23,16 +21,43 @@ class ConvertConv1dToConv2d(ExportPass):
def __init__(self, edge_program: torch.export.ExportedProgram):
super(ConvertConv1dToConv2d, self).__init__()
self.edge_program = edge_program
+ self.conv_op_map = {
+ torch.ops.aten.conv1d.default: torch.ops.aten.conv2d.default,
+ torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input,
+ }
+
+ def append_qdq(
+ self,
+ graph_module: torch.fx.GraphModule,
+ node: torch.fx.Node,
+ qdq_node: torch.fx.Node,
+ ):
+ q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
+ dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+ if qdq_node.target not in {q_op, dq_op}:
+ return node
+
+ with graph_module.graph.inserting_after(node):
+ q_args = (node, *qdq_node.args[1:])
+ q_node = graph_module.graph.create_node("call_function", q_op, q_args)
+ q_node.meta = copy_meta(node.meta)
+ q_node.meta["val"] = q_node.meta["val"].to(q_args[-1])
+ with graph_module.graph.inserting_after(q_node):
+ dq_args = (q_node, *qdq_node.args[1:])
+ dq_node = graph_module.graph.create_node(
+ "call_function", dq_op, dq_args
+ )
+ dq_node.meta = copy_meta(node.meta)
+
+ return dq_node
def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
- conv_op = exir_ops.edge.aten.convolution.default
for node in graph.nodes:
- if node.target == conv_op and node.meta["val"].dim() == 3:
-
+ if node.target in self.conv_op_map:
input_node = node.args[0]
with graph_module.graph.inserting_after(input_node):
- unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default
+ unsqueeze_op = torch.ops.aten.unsqueeze_copy.default
unsqueeze_node = graph.create_node(
"call_function",
unsqueeze_op,
@@ -44,52 +69,88 @@ def call(self, graph_module: torch.fx.GraphModule):
unsqueeze_node.meta = copy_meta(
input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
)
+ qdq_node_after_unsqueeze = self.append_qdq(
+ graph_module=graph_module,
+ node=unsqueeze_node,
+ qdq_node=input_node,
+ )
- with graph_module.graph.inserting_after(unsqueeze_node):
-
- filter_node = node.args[1]
+ with graph_module.graph.inserting_after(qdq_node_after_unsqueeze):
+ filter_arg = node.args[1]
+ filter_node = (
+ filter_arg
+ if filter_arg.op == "placeholder"
+ else node.args[1].args[0]
+ )
filter_node.meta["val"] = (
filter_node.meta["val"].unsqueeze(2).contiguous()
)
- filter_tensor = get_parameter(filter_node, self.edge_program)
- # Ensure tensor is nn.Parameter type, so program does not fail during edge_program._validate()
- filter_tensor = nn.Parameter(filter_tensor.unsqueeze(2))
- set_parameter(filter_tensor, filter_node, self.edge_program)
+ filter_tensor = get_parameter(
+ filter_node, self.edge_program
+ ).unsqueeze(2)
+ set_parameter(
+ (
+ torch.nn.Parameter(filter_tensor)
+ if filter_tensor.dtype == torch.float
+ else filter_tensor
+ ),
+ filter_node,
+ self.edge_program,
+ )
+ num_args = len(node.args)
bias_node = node.args[2]
- stride = [1] + node.args[3]
- padding = [0] + node.args[4]
- dilation = [1] + node.args[5]
- transpose = node.args[6]
- output_padding = [0] + node.args[7]
- groups = node.args[8]
-
- conv2d_node = graph.create_node(
- "call_function",
- conv_op,
- (
- unsqueeze_node,
- filter_node,
+ stride = [1] + node.args[3] if num_args > 3 else [1, 1]
+ padding = [0] + node.args[4] if num_args > 4 else [0, 0]
+ if node.target == torch.ops.aten.conv1d.default:
+ dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
+ groups = node.args[6] if num_args > 5 else 1
+ conv_args = (
+ qdq_node_after_unsqueeze,
+ node.args[1],
bias_node,
stride,
padding,
dilation,
- transpose,
+ groups,
+ )
+ else:
+ output_padding = (
+ [0] + node.args[5] if num_args > 5 else [0, 0]
+ )
+ groups = node.args[6] if num_args > 6 else 1
+ dilation = [1] + node.args[7] if num_args > 7 else [1, 1]
+ conv_args = (
+ qdq_node_after_unsqueeze,
+ node.args[1],
+ bias_node,
+ stride,
+ padding,
output_padding,
groups,
- ),
+ dilation,
+ )
+ conv2d_node = graph.create_node(
+ "call_function",
+ self.conv_op_map[node.target],
+ conv_args,
)
conv2d_node.meta = copy_meta(
node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
)
+ qdq_node_after_conv2d = self.append_qdq(
+ graph_module=graph_module,
+ node=conv2d_node,
+ qdq_node=list(node.users)[0],
+ )
- with graph_module.graph.inserting_after(conv2d_node):
- squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
+ with graph_module.graph.inserting_after(qdq_node_after_conv2d):
+ squeeze_op = torch.ops.aten.squeeze_copy.dims
squeeze_node = graph.create_node(
"call_function",
squeeze_op,
(
- conv2d_node,
+ qdq_node_after_conv2d,
[2],
),
)
@@ -102,8 +163,10 @@ def call(self, graph_module: torch.fx.GraphModule):
QCOM_REQUANTIZE
]
conv2d_node.meta.pop(QCOM_REQUANTIZE, None)
+
for user in node.users.copy():
user.replace_input_with(node, squeeze_node)
+
graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
diff --git a/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py b/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py
index 829b3757e06..4fe87604fc1 100644
--- a/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py
+++ b/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py
@@ -9,6 +9,8 @@
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass
+from .utils import dq_ops
+
class ExpandBroadcastTensorShape(ExportPass):
"""
@@ -45,9 +47,13 @@ def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule):
exir_ops.edge.aten.view_copy.default,
(arg, tuple(new_rank)),
)
+ # try skip dq_ops to get correct param node if applicable
+ arg_meta = (
+ arg.args[0].meta if arg.target in dq_ops else arg.meta
+ )
# meta needs to be copied elementwisely for fake-tensor
# to be updated correctly and not affect meta of arg
- for k, v in arg.meta.items():
+ for k, v in arg_meta.items():
reshape_node.meta[k] = v
reshape_node.meta["val"] = reshape_node.meta["val"].reshape(
new_rank
diff --git a/backends/qualcomm/_passes/fold_qdq.py b/backends/qualcomm/_passes/fold_qdq.py
index bc17b2fae1f..accf66d4c35 100644
--- a/backends/qualcomm/_passes/fold_qdq.py
+++ b/backends/qualcomm/_passes/fold_qdq.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
+from executorch.backends.qualcomm.builders.utils import is_parameter
+from executorch.backends.qualcomm.utils.constants import QCOM_BYPASS_NODE
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass
@@ -16,23 +18,38 @@ class FoldQDQ(ExportPass):
Erase QDQ pattern.
"""
- def __init__(self):
+ def __init__(self, edge_program: torch.export.ExportedProgram, force_fold=False):
super(FoldQDQ, self).__init__()
+ self.edge_program = edge_program
+ self.force_fold = force_fold
- def _fold(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+ def _annotate_bypass(self, node):
+ node.meta[QCOM_BYPASS_NODE] = True
+ for arg in node.args:
+ if isinstance(arg, torch.fx.Node) and arg.op == "call_function":
+ self._annotate_bypass(arg)
+
+ def _fold_dq(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
# remove dq
for n in graph_module.graph.nodes:
user_list = list(n.users.keys())
if n.target not in dq_ops:
continue
- for user_n in user_list:
- user_n.replace_input_with(n, n.args[0])
- graph_module.graph.erase_node(n)
+ # skip parameters & buffers
+ if not self.force_fold and is_parameter(n.args[0], self.edge_program):
+ self._annotate_bypass(n)
+ else:
+ for user_n in user_list:
+ user_n.replace_input_with(n, n.args[0])
+ graph_module.graph.erase_node(n)
+
+ def _fold_q(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
# remove q
for n in graph_module.graph.nodes:
if n.target not in q_ops:
continue
+
to_be_removed = [n]
source_n = n.args[0]
@@ -57,7 +74,8 @@ def _fold(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
graph_module.graph.erase_node(n)
def call(self, graph_module: torch.fx.GraphModule):
- self._fold(graph_module)
+ self._fold_dq(graph_module)
+ self._fold_q(graph_module)
graph_module.recompile()
dead_code_elimination_pass(graph_module)
return PassResult(graph_module, True)
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index c98f27db120..63c303eb689 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -200,6 +200,9 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram):
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
self.add_pass(DecomposeExpM1())
+ # this pass will rewrite state_dict, it needs to be accomplished before
+ # to_edge_transform_and_lower
+ self.add_pass(ConvertConv1dToConv2d(exported_program))
self.add_pass(ConvertSquareToPow())
self.add_pass(LiftConstantScalarOperands())
self._transform(exported_program.graph_module)
@@ -207,6 +210,7 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram):
return ep
def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
+ self.add_pass(FoldQDQ(exported_program, force_fold=True))
self.add_pass(InsertRequantize())
self.add_pass(InsertIOQDQ(exported_program))
self.add_pass(LayoutTransform(exported_program, insert_permute=True))
diff --git a/backends/qualcomm/_passes/replace_index_put_input.py b/backends/qualcomm/_passes/replace_index_put_input.py
index dcdf2bb3a7f..93ee21bfc7c 100644
--- a/backends/qualcomm/_passes/replace_index_put_input.py
+++ b/backends/qualcomm/_passes/replace_index_put_input.py
@@ -33,7 +33,8 @@ def call(self, graph_module: torch.fx.GraphModule):
copy_node := list(node.users)[0]
) and copy_node.target == exir_ops.edge.aten.copy.default:
m_buffer_node = copy_node.args[0]
- bad_frozen_node = node.args[0]
+ dq_node = node.args[0]
+ bad_frozen_node = dq_node.args[0]
if QCOM_QUANT_ATTRS in bad_frozen_node.meta:
m_buffer_node.meta[QCOM_QUANT_ATTRS] = bad_frozen_node.meta[
QCOM_QUANT_ATTRS
@@ -43,8 +44,8 @@ def call(self, graph_module: torch.fx.GraphModule):
m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]
]
)
- with graph.inserting_after(bad_frozen_node):
- node.replace_input_with(bad_frozen_node, m_buffer_node)
+ with graph.inserting_after(dq_node):
+ node.replace_input_with(dq_node, m_buffer_node)
else:
continue
diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py
index 46d9e0cde76..10dcbb07aac 100755
--- a/backends/qualcomm/_passes/utils.py
+++ b/backends/qualcomm/_passes/utils.py
@@ -103,7 +103,6 @@ def get_passes_dependency_for_capture_program():
AnnotateStack: [RemoveRedundancy],
AnnotateUnbind: [RemoveRedundancy],
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
- ConvertConv1dToConv2d: [FoldQDQ],
ConvertUpsampleBicubicWithBilinear: [RemoveRedundancy],
DecomposeAny: [RemoveRedundancy],
DecomposeLinalgVectorNorm: [RemoveRedundancy],
diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md
index 783a53dd645..22f0852941c 100644
--- a/backends/qualcomm/builders/README.md
+++ b/backends/qualcomm/builders/README.md
@@ -227,7 +227,7 @@ Now, we can start to fill in function body step by step:
2. Define input gamma / beta tensors:
```python
- weight_node = node.args[2]
+ weight_node = self.get_node(node.args[2])
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
@@ -237,7 +237,7 @@ Now, we can start to fill in function body step by step:
nodes_to_wrappers,
)
- bias_node = node.args[3]
+ bias_node = self.get_node(node.args[3])
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py
index 7965a30caea..5e9520d4c05 100644
--- a/backends/qualcomm/builders/node_visitor.py
+++ b/backends/qualcomm/builders/node_visitor.py
@@ -11,6 +11,7 @@
import numpy as np
import torch
+from executorch.backends.qualcomm._passes.utils import dq_ops
from executorch.backends.qualcomm.utils.constants import (
QCOM_AXIS,
QCOM_AXIS_ORDER,
@@ -18,7 +19,6 @@
QCOM_BLOCK_SCALE_BITWIDTH,
QCOM_BLOCK_SCALE_OFFSET,
QCOM_BLOCK_SCALES,
- QCOM_BLOCK_SIZE,
QCOM_BLOCK_STORAGE_TYPE,
QCOM_DTYPE,
QCOM_ENCODING,
@@ -95,6 +95,19 @@ def __init__(
self.edge_program = edge_program
self.enable_tensor_dump = enable_tensor_dump
+ def get_node(self, node):
+ """
+ Utility to skip dequantize node for frozen param
+ """
+ return node.args[0] if node is not None and node.target in dq_ops else node
+
+ def get_first_user(self, node):
+ """
+ Utility to skip dequantize user for frozen param
+ """
+ user_0 = list(node.users)[0]
+ return user_0 if user_0.target not in dq_ops else self.get_first_user(user_0)
+
def get_tensor(self, input_node, op_node, idx=None):
"""
Get tensor value/shape with axis_order
@@ -142,7 +155,9 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
# symmetric quantization is required
scale_offset.append(PyQnnWrapper.Qnn_ScaleOffset_t(max_scale, 0))
- if "convolution" in list(node.users)[0].target.__name__:
+ # skip dequantize op, e.g. frozen_param -> dq -> conv2d
+ user_0 = self.get_first_user(node)
+ if "convolution" in user_0.target.__name__:
# OIHW (pytorch) -> HWIO (QNN)
quant_config[QCOM_AXIS] = 3
quant_config[QCOM_AXIS_ORDER] = (2, 3, 1, 0)
@@ -178,14 +193,11 @@ def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict):
PyQnnWrapper.Qnn_ScaleOffset_t(scales[i], -zero_points[i])
)
- user_0 = list(node.users)[0]
+ # skip dequantize op, e.g. frozen_param -> dq -> conv2d
+ user_0 = self.get_first_user(node)
# Memory layout of QNN conv weight always ends in Output. Like conv2d is HWIO
- if (
- "convolution" in user_0.target.__name__
- and list(node.users)[0].args[1] == node
- ):
+ if "convolution" in user_0.target.__name__:
quant_config[QCOM_AXIS] = 3
-
else:
quant_config[QCOM_AXIS] = quant_attrs[QCOM_AXIS]
@@ -256,33 +268,21 @@ def get_quant_encoding_conf(
def get_quant_tensor_value(
self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict
) -> torch.Tensor:
- dtype = quant_configs[QCOM_DTYPE]
- if quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING:
+ # params should have been quantized by framework
+ # here we're handling constant operators like arange, full, etc.
+ if tensor.dtype == torch.float32:
+ assert quant_attrs[QCOM_ENCODING] in PER_TENSOR_ENCODING, (
+ f"unrecongnized quantization attribute detected {quant_attrs[QCOM_ENCODING]}",
+ )
scale = quant_attrs[QCOM_SCALE]
zero_point = quant_attrs[QCOM_ZERO_POINT]
- tensor = tensor.div(scale).add(zero_point).round().to(dtype)
- elif quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
- scale = quant_attrs[QCOM_SCALES]
- zero_point = quant_attrs[QCOM_ZERO_POINTS]
- tensor = tensor.div(scale).add(zero_point).round().to(dtype)
- else: # per_block
- if axis_order := quant_configs.get(QCOM_AXIS_ORDER, None):
- origin_order = tuple(
- axis_order.index(x) for x in range(len(axis_order))
- )
- tensor = tensor.permute(origin_order)
- tensor = torch.ops.pt2e_quant.quantize_affine(
- tensor,
- block_size=quant_attrs[QCOM_BLOCK_SIZE],
- scale=quant_attrs[QCOM_SCALE],
- zero_point=quant_attrs[QCOM_ZERO_POINT],
- output_dtype=dtype,
- quant_min=quant_attrs[QCOM_QUANT_MIN],
- quant_max=quant_attrs[QCOM_QUANT_MAX],
+ tensor = (
+ tensor.div(scale).add(zero_point).round().to(quant_configs[QCOM_DTYPE])
)
- if axis_order:
- tensor = tensor.permute(axis_order)
-
+ # Since we're using torch.int32 to store 16bit data
+ # need to make it compact here for QNN to correctly retrieve data
+ if quant_configs.get(QCOM_DTYPE) == torch.uint16:
+ tensor = tensor.to(torch.uint16)
# Make the backends access data correctly
if quant_configs.get(QCOM_BITWIDTH) == 4:
mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
diff --git a/backends/qualcomm/builders/op_abs.py b/backends/qualcomm/builders/op_abs.py
index 002ffe85208..2209ffc792c 100644
--- a/backends/qualcomm/builders/op_abs.py
+++ b/backends/qualcomm/builders/op_abs.py
@@ -35,7 +35,7 @@ def define_node(
)
abs_output_tensors = [output_tensor_wrapper]
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor_wrapper = self.define_tensor(
input_node,
node,
diff --git a/backends/qualcomm/builders/op_adaptive_avg_pool2d.py b/backends/qualcomm/builders/op_adaptive_avg_pool2d.py
index c944e1646e7..777e1f61ada 100644
--- a/backends/qualcomm/builders/op_adaptive_avg_pool2d.py
+++ b/backends/qualcomm/builders/op_adaptive_avg_pool2d.py
@@ -28,7 +28,7 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_add.py b/backends/qualcomm/builders/op_add.py
index b5edfd7bb52..f8fb31fb725 100644
--- a/backends/qualcomm/builders/op_add.py
+++ b/backends/qualcomm/builders/op_add.py
@@ -37,7 +37,7 @@ def define_node(
add_input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_amax.py b/backends/qualcomm/builders/op_amax.py
index 099004a4bcf..62c17b8dfcd 100644
--- a/backends/qualcomm/builders/op_amax.py
+++ b/backends/qualcomm/builders/op_amax.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_and.py b/backends/qualcomm/builders/op_and.py
index 44e6f2893f5..22b63e0d6ff 100644
--- a/backends/qualcomm/builders/op_and.py
+++ b/backends/qualcomm/builders/op_and.py
@@ -37,7 +37,7 @@ def define_node(
and_input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_argmin.py b/backends/qualcomm/builders/op_argmin.py
index 5630b02a5cc..fa3fad4a61b 100644
--- a/backends/qualcomm/builders/op_argmin.py
+++ b/backends/qualcomm/builders/op_argmin.py
@@ -27,7 +27,7 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
op_wrapper_list = []
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
output_tensor = self.get_tensor(node, node)
argmin_inp_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_avg_pool2d.py b/backends/qualcomm/builders/op_avg_pool2d.py
index 394d4008587..f4762e8bb5a 100644
--- a/backends/qualcomm/builders/op_avg_pool2d.py
+++ b/backends/qualcomm/builders/op_avg_pool2d.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_batch_norm.py b/backends/qualcomm/builders/op_batch_norm.py
index 9aed1401875..ec0a7c39348 100644
--- a/backends/qualcomm/builders/op_batch_norm.py
+++ b/backends/qualcomm/builders/op_batch_norm.py
@@ -40,14 +40,22 @@ def update_encoding(self, node: torch.fx.Node, tensor: torch.Tensor, eps):
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
# scale value equals to zero will cause failure in HTP
diff = max(abs(tensor.max()), abs(tensor.min())) + eps
- quant_attrs[QCOM_SCALE] = diff / quant_attrs[QCOM_QUANT_MAX]
+ quant_attrs[QCOM_SCALE] = (diff / quant_attrs[QCOM_QUANT_MAX]).item()
+
+ def try_dequantize(self, node: torch.fx.Node, tensor: torch.Tensor):
+ if tensor.dtype == torch.float:
+ return tensor
+
+ scale = node.meta[QCOM_QUANT_ATTRS][QCOM_SCALE]
+ offset = node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT]
+ return tensor.sub(offset).mul(scale).to(torch.float32).contiguous()
def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
eps = 1e-9
@@ -78,9 +86,12 @@ def define_node(
batch_norm_output_tensors = [output_tensor_wrapper]
n_feature = output_tensor.shape[-1 if QCOM_AXIS_ORDER in node.meta else 1]
- filter_node = node.args[1]
+ filter_node = self.get_node(node.args[1])
if filter_node is not None:
- filter_tensor = get_parameter(filter_node, self.edge_program)
+ # dequantize here for post-process
+ filter_tensor = self.try_dequantize(
+ filter_node, get_parameter(filter_node, self.edge_program)
+ )
else:
# 'graph', 'name', 'op', 'target', 'args', and 'kwargs'
filter_node = torch.fx.Node(
@@ -110,9 +121,12 @@ def define_node(
)
batch_norm_input_tensors.append(filter_tensor_wrapper)
- bias_node = node.args[2]
+ bias_node = self.get_node(node.args[2])
if bias_node is not None:
- bias_tensor = get_parameter(bias_node, self.edge_program)
+ # dequantize here for post-process
+ bias_tensor = self.try_dequantize(
+ bias_node, get_parameter(bias_node, self.edge_program)
+ )
amount = (filter_tensor * mean_tensor) / torch.sqrt(var_tensor + eps)
bias_tensor = bias_tensor - amount
self.update_encoding(bias_node, bias_tensor, eps)
diff --git a/backends/qualcomm/builders/op_bmm.py b/backends/qualcomm/builders/op_bmm.py
index 46fbff1cc7e..d473d085490 100644
--- a/backends/qualcomm/builders/op_bmm.py
+++ b/backends/qualcomm/builders/op_bmm.py
@@ -27,7 +27,7 @@ def define_node(
) -> PyQnnWrapper.PyQnnOpWrapper:
bmm_input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_cat.py b/backends/qualcomm/builders/op_cat.py
index 7f160856390..09f99396589 100644
--- a/backends/qualcomm/builders/op_cat.py
+++ b/backends/qualcomm/builders/op_cat.py
@@ -32,7 +32,7 @@ def define_node(
list_of_tensor_wrappers = []
for tensor_input in list_of_tensors:
- input_tensor = self.get_tensor(tensor_input, node)
+ input_tensor = self.get_tensor(self.get_node(tensor_input), node)
list_of_tensor_wrappers.append(
self.define_tensor(
tensor_input,
diff --git a/backends/qualcomm/builders/op_ceil.py b/backends/qualcomm/builders/op_ceil.py
index 19fe14d6392..f0a43846d11 100644
--- a/backends/qualcomm/builders/op_ceil.py
+++ b/backends/qualcomm/builders/op_ceil.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_clamp.py b/backends/qualcomm/builders/op_clamp.py
index 0f9a9ffa196..e80c99db352 100644
--- a/backends/qualcomm/builders/op_clamp.py
+++ b/backends/qualcomm/builders/op_clamp.py
@@ -27,7 +27,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py
index c019a835223..5a168ca103a 100644
--- a/backends/qualcomm/builders/op_conv2d.py
+++ b/backends/qualcomm/builders/op_conv2d.py
@@ -104,7 +104,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
assert (
input_tensor.dim() == 4
@@ -117,7 +117,7 @@ def define_node(
nodes_to_wrappers,
)
- filter_node = node.args[1]
+ filter_node = self.get_node(node.args[1])
filter_tensor = get_parameter(filter_node, self.edge_program)
# weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO
is_transpose_conv = cast(bool, node.args[6])
@@ -133,7 +133,7 @@ def define_node(
conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
if node.args[2] is not None:
- bias_node = node.args[2]
+ bias_node = self.get_node(node.args[2])
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
diff --git a/backends/qualcomm/builders/op_cos.py b/backends/qualcomm/builders/op_cos.py
index 589bf3ef88e..69c0d40a026 100644
--- a/backends/qualcomm/builders/op_cos.py
+++ b/backends/qualcomm/builders/op_cos.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_cum_sum.py b/backends/qualcomm/builders/op_cum_sum.py
index f62485bc519..dceaea83345 100644
--- a/backends/qualcomm/builders/op_cum_sum.py
+++ b/backends/qualcomm/builders/op_cum_sum.py
@@ -37,7 +37,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_depth_to_space.py b/backends/qualcomm/builders/op_depth_to_space.py
index 56c57b4bd5e..357b7a81039 100644
--- a/backends/qualcomm/builders/op_depth_to_space.py
+++ b/backends/qualcomm/builders/op_depth_to_space.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_div.py b/backends/qualcomm/builders/op_div.py
index ce3f96abc7f..399e914e290 100644
--- a/backends/qualcomm/builders/op_div.py
+++ b/backends/qualcomm/builders/op_div.py
@@ -37,7 +37,7 @@ def define_node(
div_input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_elu.py b/backends/qualcomm/builders/op_elu.py
index f9cc089c7bb..f0ac422f4b8 100644
--- a/backends/qualcomm/builders/op_elu.py
+++ b/backends/qualcomm/builders/op_elu.py
@@ -28,7 +28,7 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
# tensor input
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py
index 5b0d1600393..ba5b1a02077 100644
--- a/backends/qualcomm/builders/op_embedding.py
+++ b/backends/qualcomm/builders/op_embedding.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- weight_node = node.args[0]
+ weight_node = self.get_node(node.args[0])
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
diff --git a/backends/qualcomm/builders/op_eq.py b/backends/qualcomm/builders/op_eq.py
index 855c5e13be6..6f33ea78bd1 100644
--- a/backends/qualcomm/builders/op_eq.py
+++ b/backends/qualcomm/builders/op_eq.py
@@ -37,7 +37,7 @@ def define_node(
input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_exp.py b/backends/qualcomm/builders/op_exp.py
index 8c4794c9725..f736dec85c2 100644
--- a/backends/qualcomm/builders/op_exp.py
+++ b/backends/qualcomm/builders/op_exp.py
@@ -26,7 +26,7 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
# tensor input
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_expand.py b/backends/qualcomm/builders/op_expand.py
index c098ed00c94..31d248638ab 100644
--- a/backends/qualcomm/builders/op_expand.py
+++ b/backends/qualcomm/builders/op_expand.py
@@ -27,7 +27,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_ge.py b/backends/qualcomm/builders/op_ge.py
index 6784167aa5b..28a29829731 100644
--- a/backends/qualcomm/builders/op_ge.py
+++ b/backends/qualcomm/builders/op_ge.py
@@ -37,7 +37,7 @@ def define_node(
input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_gelu.py b/backends/qualcomm/builders/op_gelu.py
index c178740448e..02356a2eef5 100644
--- a/backends/qualcomm/builders/op_gelu.py
+++ b/backends/qualcomm/builders/op_gelu.py
@@ -26,7 +26,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_group_norm.py b/backends/qualcomm/builders/op_group_norm.py
index 26700216b53..a52569cfa7a 100644
--- a/backends/qualcomm/builders/op_group_norm.py
+++ b/backends/qualcomm/builders/op_group_norm.py
@@ -29,7 +29,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
@@ -39,7 +39,7 @@ def define_node(
nodes_to_wrappers,
)
- weight_node = node.args[1]
+ weight_node = self.get_node(node.args[1])
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
@@ -49,7 +49,7 @@ def define_node(
nodes_to_wrappers,
)
- bias_node = node.args[2]
+ bias_node = self.get_node(node.args[2])
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
diff --git a/backends/qualcomm/builders/op_gt.py b/backends/qualcomm/builders/op_gt.py
index 6c311f42b7f..8c1ef3a600c 100644
--- a/backends/qualcomm/builders/op_gt.py
+++ b/backends/qualcomm/builders/op_gt.py
@@ -37,7 +37,7 @@ def define_node(
input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_hardsigmoid.py b/backends/qualcomm/builders/op_hardsigmoid.py
index 1acc08a387d..c30cae92f55 100644
--- a/backends/qualcomm/builders/op_hardsigmoid.py
+++ b/backends/qualcomm/builders/op_hardsigmoid.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_hardswish.py b/backends/qualcomm/builders/op_hardswish.py
index ed28ff95f78..fb4d0a40515 100644
--- a/backends/qualcomm/builders/op_hardswish.py
+++ b/backends/qualcomm/builders/op_hardswish.py
@@ -26,7 +26,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_hardtanh.py b/backends/qualcomm/builders/op_hardtanh.py
index 68bafaaab8b..4025a060ff3 100644
--- a/backends/qualcomm/builders/op_hardtanh.py
+++ b/backends/qualcomm/builders/op_hardtanh.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_index.py b/backends/qualcomm/builders/op_index.py
index ff039f9d7a8..fe6bf4262d8 100644
--- a/backends/qualcomm/builders/op_index.py
+++ b/backends/qualcomm/builders/op_index.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py
index c317cc0a8b7..9c11a6ca891 100644
--- a/backends/qualcomm/builders/op_index_put.py
+++ b/backends/qualcomm/builders/op_index_put.py
@@ -20,7 +20,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
@@ -50,7 +50,7 @@ def define_node(
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
- value_node = node.args[2]
+ value_node = self.get_node(node.args[2])
value_tensor = self.get_tensor(value_node, node)
diff --git a/backends/qualcomm/builders/op_instance_norm.py b/backends/qualcomm/builders/op_instance_norm.py
index e7e7f14a944..828e89a97f2 100644
--- a/backends/qualcomm/builders/op_instance_norm.py
+++ b/backends/qualcomm/builders/op_instance_norm.py
@@ -35,7 +35,9 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node, weight_node, bias_node = node.args[0:3]
+ input_node = self.get_node(node.args[0])
+ weight_node = self.get_node(node.args[1])
+ bias_node = self.get_node(node.args[2])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py
index 06f822014ed..5316cb1dabe 100644
--- a/backends/qualcomm/builders/op_layer_norm.py
+++ b/backends/qualcomm/builders/op_layer_norm.py
@@ -30,7 +30,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
@@ -53,7 +53,7 @@ def define_node(
axis = [len(input_tensor.shape) - 1]
axis_shape = [len(axis)]
- weight_node = node.args[2]
+ weight_node = self.get_node(node.args[2])
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
@@ -65,7 +65,7 @@ def define_node(
layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]
- bias_node = node.args[3]
+ bias_node = self.get_node(node.args[3])
if bias_node is not None:
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_le.py b/backends/qualcomm/builders/op_le.py
index 1dd2a06b777..e5784049c5c 100644
--- a/backends/qualcomm/builders/op_le.py
+++ b/backends/qualcomm/builders/op_le.py
@@ -37,7 +37,7 @@ def define_node(
input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py
index 71b6072b9e5..71716e81bca 100644
--- a/backends/qualcomm/builders/op_linear.py
+++ b/backends/qualcomm/builders/op_linear.py
@@ -34,7 +34,7 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
linear_input_tensors = []
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
@@ -45,7 +45,7 @@ def define_node(
)
linear_input_tensors.append(input_tensor_wrapper)
- weight_node = node.args[1]
+ weight_node = self.get_node(node.args[1])
if (
quant_attrs := weight_node.meta.get(QCOM_QUANT_ATTRS)
) and QCOM_SCALES in quant_attrs:
@@ -67,7 +67,7 @@ def define_node(
linear_input_tensors.append(weight_tensor_wrapper)
if len(node.args) >= 3:
- bias_node = node.args[2]
+ bias_node = self.get_node(node.args[2])
# TODO remove this when qnn sdk support
if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}):
diff --git a/backends/qualcomm/builders/op_log.py b/backends/qualcomm/builders/op_log.py
index bcc40aa6268..65125e42316 100644
--- a/backends/qualcomm/builders/op_log.py
+++ b/backends/qualcomm/builders/op_log.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
log_inp_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_log_softmax.py b/backends/qualcomm/builders/op_log_softmax.py
index d395d5eb66e..2d6c857591e 100644
--- a/backends/qualcomm/builders/op_log_softmax.py
+++ b/backends/qualcomm/builders/op_log_softmax.py
@@ -27,7 +27,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
log_softmax_inp_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_logical_not.py b/backends/qualcomm/builders/op_logical_not.py
index 457a1007ada..1eed7d894de 100644
--- a/backends/qualcomm/builders/op_logical_not.py
+++ b/backends/qualcomm/builders/op_logical_not.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_lt.py b/backends/qualcomm/builders/op_lt.py
index b4a080efc38..9494aac9d29 100644
--- a/backends/qualcomm/builders/op_lt.py
+++ b/backends/qualcomm/builders/op_lt.py
@@ -37,7 +37,7 @@ def define_node(
input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_matmul.py b/backends/qualcomm/builders/op_matmul.py
index 577bcb12a42..8d45424bd62 100644
--- a/backends/qualcomm/builders/op_matmul.py
+++ b/backends/qualcomm/builders/op_matmul.py
@@ -27,7 +27,7 @@ def define_node(
) -> PyQnnWrapper.PyQnnOpWrapper:
matmul_input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_max.py b/backends/qualcomm/builders/op_max.py
index 7d41358a266..57e119922ed 100644
--- a/backends/qualcomm/builders/op_max.py
+++ b/backends/qualcomm/builders/op_max.py
@@ -37,7 +37,7 @@ def define_node(
min_input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_max_pool2d.py b/backends/qualcomm/builders/op_max_pool2d.py
index 8d0087eb2c6..a0ef685acd0 100644
--- a/backends/qualcomm/builders/op_max_pool2d.py
+++ b/backends/qualcomm/builders/op_max_pool2d.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_mean_dim.py b/backends/qualcomm/builders/op_mean_dim.py
index 313b24420db..8fb0e9e3c95 100644
--- a/backends/qualcomm/builders/op_mean_dim.py
+++ b/backends/qualcomm/builders/op_mean_dim.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_min.py b/backends/qualcomm/builders/op_min.py
index 0df2796974d..72224500b0e 100644
--- a/backends/qualcomm/builders/op_min.py
+++ b/backends/qualcomm/builders/op_min.py
@@ -37,7 +37,7 @@ def define_node(
min_input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_mul.py b/backends/qualcomm/builders/op_mul.py
index 3138d3b8c9b..36e0c91cf7a 100644
--- a/backends/qualcomm/builders/op_mul.py
+++ b/backends/qualcomm/builders/op_mul.py
@@ -37,7 +37,7 @@ def define_node(
mul_input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_ne.py b/backends/qualcomm/builders/op_ne.py
index 0227b02efbf..e9b723a88c5 100644
--- a/backends/qualcomm/builders/op_ne.py
+++ b/backends/qualcomm/builders/op_ne.py
@@ -8,14 +8,6 @@
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import torch
-from executorch.backends.qualcomm.utils.constants import (
- QCOM_QUANT_ATTRS,
- QCOM_QUANT_MAX,
- QCOM_QUANT_MIN,
- QCOM_SCALE,
- QCOM_ZERO_POINT,
-)
-from executorch.exir.dialects._ops import ops as exir_ops
from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWiseNotEqual, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -45,38 +37,9 @@ def define_node(
input_tensors = []
for index in range(2):
- input_node = node.args[index]
- if isinstance(input_node, torch.fx.Node):
- input_tensor = self.get_tensor(input_node, node)
- tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
- else:
- scalar = input_node
- input_tensor = torch.tensor(scalar, dtype=torch.float32)
- tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
-
- # 'graph', 'name', 'op', 'target', 'args', and 'kwargs'
- input_node = torch.fx.Node(
- node.graph,
- node.name + "_runtime_scalar",
- "call_function",
- exir_ops.edge.aten.scalar_tensor.default,
- (), # args
- {}, # kwargs
- )
- # Because the output data type of the ne node is boolean.
- # We need to take the quant attr from the non-scalar node.
- if quant_attrs := node.args[index ^ 1].meta.get(QCOM_QUANT_ATTRS):
- quant_attrs = quant_attrs.copy()
- quant_range = (
- quant_attrs[QCOM_QUANT_MAX] - quant_attrs[QCOM_QUANT_MIN]
- )
- quant_attrs[QCOM_ZERO_POINT] = (
- 0 if scalar >= 0 else quant_attrs[QCOM_QUANT_MAX]
- )
- quant_attrs[QCOM_SCALE] = (
- scalar / quant_range if scalar >= 0 else -scalar / quant_range
- )
- input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
+ input_node = self.get_node(node.args[index])
+ input_tensor = self.get_tensor(input_node, node)
+ tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_neg.py b/backends/qualcomm/builders/op_neg.py
index a950a1887ab..fd48cbe2791 100644
--- a/backends/qualcomm/builders/op_neg.py
+++ b/backends/qualcomm/builders/op_neg.py
@@ -24,7 +24,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
neg_inp_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_or.py b/backends/qualcomm/builders/op_or.py
index c2751744788..483831db0f7 100644
--- a/backends/qualcomm/builders/op_or.py
+++ b/backends/qualcomm/builders/op_or.py
@@ -37,7 +37,7 @@ def define_node(
or_input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_pad.py b/backends/qualcomm/builders/op_pad.py
index 5ec34065f8b..7b210ed6838 100644
--- a/backends/qualcomm/builders/op_pad.py
+++ b/backends/qualcomm/builders/op_pad.py
@@ -27,7 +27,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
pad_inp_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_pow.py b/backends/qualcomm/builders/op_pow.py
index 3e89bdcfc4d..996d3b353e2 100644
--- a/backends/qualcomm/builders/op_pow.py
+++ b/backends/qualcomm/builders/op_pow.py
@@ -37,7 +37,7 @@ def define_node(
pow_output_tensors = [output_tensor_wrapper]
# tensor input
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
@@ -51,7 +51,7 @@ def define_node(
)
# exp input
- exp_node = node.args[1]
+ exp_node = self.get_node(node.args[1])
exp_tensor = self.get_tensor(exp_node, node)
exp_tensor_wrapper = self.define_tensor(
exp_node,
diff --git a/backends/qualcomm/builders/op_prelu.py b/backends/qualcomm/builders/op_prelu.py
index e35839f535e..69ea5e005a7 100644
--- a/backends/qualcomm/builders/op_prelu.py
+++ b/backends/qualcomm/builders/op_prelu.py
@@ -26,7 +26,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
prelu_inp_tensor_wrapper = self.define_tensor(
input_node,
@@ -36,23 +36,18 @@ def define_node(
nodes_to_wrappers,
)
- coeff_node = node.args[1]
- coeff_tensor = torch.zeros(input_node.meta["val"].shape)
+ coeff_node = self.get_node(node.args[1])
coeff = get_parameter(coeff_node, self.edge_program)
- # param nodes will be FakeTensor when doing partition
- # fill in random numeric for validation
- if isinstance(coeff, torch._subclasses.fake_tensor.FakeTensor):
- coeff = torch.ones(coeff.shape)
+ coeff_tensor = torch.zeros(input_node.meta["val"].shape, dtype=coeff.dtype)
# per-channel activation
if coeff_node.meta["val"].shape[0] > 1:
for i in range(input_node.meta["val"].shape[1]):
coeff_tensor = coeff_tensor.index_fill(1, torch.tensor([i]), coeff[i])
- if QCOM_AXIS_ORDER in input_node.meta:
- axis_order = input_node.meta[QCOM_AXIS_ORDER]
- coeff_tensor = coeff_tensor.permute(dims=axis_order).contiguous()
else:
- coeff = coeff.item()
- coeff_tensor = torch.full(input_tensor.shape, coeff).to(torch.float32)
+ coeff_tensor.fill_(coeff[0])
+
+ if axis_order := input_node.meta.get(QCOM_AXIS_ORDER, None):
+ coeff_tensor = coeff_tensor.permute(dims=axis_order).contiguous()
coeff_tensor_wrapper = self.define_tensor(
coeff_node,
diff --git a/backends/qualcomm/builders/op_relu.py b/backends/qualcomm/builders/op_relu.py
index 29335797e28..d237b84efe1 100644
--- a/backends/qualcomm/builders/op_relu.py
+++ b/backends/qualcomm/builders/op_relu.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
relu_inp_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_repeat.py b/backends/qualcomm/builders/op_repeat.py
index 9748f1e9619..e5867e64447 100644
--- a/backends/qualcomm/builders/op_repeat.py
+++ b/backends/qualcomm/builders/op_repeat.py
@@ -27,7 +27,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_reshape.py b/backends/qualcomm/builders/op_reshape.py
index ff4a603fa5b..6e25c65e16d 100644
--- a/backends/qualcomm/builders/op_reshape.py
+++ b/backends/qualcomm/builders/op_reshape.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py
index aa7f9becd98..fdf49b09fef 100644
--- a/backends/qualcomm/builders/op_rms_norm.py
+++ b/backends/qualcomm/builders/op_rms_norm.py
@@ -36,7 +36,7 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
# args of node : ['input', 'normalized_shape', 'weight', 'eps']
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
@@ -60,7 +60,7 @@ def define_node(
axes = [node.args[0].meta["val"].dim() - 1]
axes_shape = [len(axes)]
- weight_node = node.args[2]
+ weight_node = self.get_node(node.args[2])
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
@@ -71,7 +71,7 @@ def define_node(
)
# Fake node, nn module seems to be inconsistent with document
- bias_tensor = torch.zeros(weight_tensor.shape)
+ bias_tensor = torch.zeros(weight_tensor.shape, dtype=weight_tensor.dtype)
bias_node = torch.fx.Node(
node.graph,
node.name + "_runtime_bias",
diff --git a/backends/qualcomm/builders/op_rsqrt.py b/backends/qualcomm/builders/op_rsqrt.py
index 162b485e9e5..b1995e28dde 100644
--- a/backends/qualcomm/builders/op_rsqrt.py
+++ b/backends/qualcomm/builders/op_rsqrt.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
rsqrt_inp_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_scalar_tensor.py b/backends/qualcomm/builders/op_scalar_tensor.py
index d236f6674df..2e9154115bc 100644
--- a/backends/qualcomm/builders/op_scalar_tensor.py
+++ b/backends/qualcomm/builders/op_scalar_tensor.py
@@ -13,7 +13,7 @@
@register_node_visitor
-class Arange(NodeVisitor):
+class ScalarTensor(NodeVisitor):
target = ["scalar_tensor.default"]
def __init__(self, *args) -> None:
diff --git a/backends/qualcomm/builders/op_select_copy.py b/backends/qualcomm/builders/op_select_copy.py
index 148888f1497..c5a7c0f7c99 100644
--- a/backends/qualcomm/builders/op_select_copy.py
+++ b/backends/qualcomm/builders/op_select_copy.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_sigmoid.py b/backends/qualcomm/builders/op_sigmoid.py
index ae6e6709c0a..ce820c8f4ee 100644
--- a/backends/qualcomm/builders/op_sigmoid.py
+++ b/backends/qualcomm/builders/op_sigmoid.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
sigmoid_inp_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_sin.py b/backends/qualcomm/builders/op_sin.py
index 8828685ac9e..f9a0b1c2e63 100644
--- a/backends/qualcomm/builders/op_sin.py
+++ b/backends/qualcomm/builders/op_sin.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py
index 8d12e03c0bb..7d3a154e9f1 100644
--- a/backends/qualcomm/builders/op_slice_copy.py
+++ b/backends/qualcomm/builders/op_slice_copy.py
@@ -26,7 +26,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
@@ -50,12 +50,17 @@ def define_node(
dim = cast(int, node.args[1])
if dim < 0:
dim = dim % len(input_tensor.shape)
- start = cast(int, node.args[2])
+
+ start = 0 if node.args[2] is None else cast(int, node.args[2])
if start < 0:
start = start % input_tensor.shape[dim]
- end = min(cast(int, node.args[3]), input_tensor.shape[dim])
- if end < 0:
- end = end % input_tensor.shape[dim]
+
+ if len(node.args) > 3:
+ end = min(cast(int, node.args[3]), input_tensor.shape[dim])
+ if end < 0:
+ end = end % input_tensor.shape[dim]
+ else:
+ end = input_tensor.shape[dim]
input_tensor_rank = len(input_tensor.shape)
ranges = []
diff --git a/backends/qualcomm/builders/op_softmax.py b/backends/qualcomm/builders/op_softmax.py
index f6f826e2a40..43cc6438b9b 100644
--- a/backends/qualcomm/builders/op_softmax.py
+++ b/backends/qualcomm/builders/op_softmax.py
@@ -27,7 +27,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
softmax_inp_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_space_to_depth.py b/backends/qualcomm/builders/op_space_to_depth.py
index 0282cf3f15a..84c79d841d8 100644
--- a/backends/qualcomm/builders/op_space_to_depth.py
+++ b/backends/qualcomm/builders/op_space_to_depth.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_split_with_sizes.py b/backends/qualcomm/builders/op_split_with_sizes.py
index 138f6ed60ec..b70d74aa339 100644
--- a/backends/qualcomm/builders/op_split_with_sizes.py
+++ b/backends/qualcomm/builders/op_split_with_sizes.py
@@ -28,7 +28,7 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_sqrt.py b/backends/qualcomm/builders/op_sqrt.py
index 5505e92ee67..ff5a0c086e0 100644
--- a/backends/qualcomm/builders/op_sqrt.py
+++ b/backends/qualcomm/builders/op_sqrt.py
@@ -26,7 +26,7 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
# tensor input
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_squeeze.py b/backends/qualcomm/builders/op_squeeze.py
index b828bb7b0b9..94d6e5a3cf9 100644
--- a/backends/qualcomm/builders/op_squeeze.py
+++ b/backends/qualcomm/builders/op_squeeze.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_stack.py b/backends/qualcomm/builders/op_stack.py
index fdef148ad4d..25b7d353dc4 100644
--- a/backends/qualcomm/builders/op_stack.py
+++ b/backends/qualcomm/builders/op_stack.py
@@ -30,7 +30,7 @@ def define_node(
input_node_list = node.args[0]
stack_input_tensors = []
for input_node in input_node_list:
- input_tensor = self.get_tensor(input_node, node)
+ input_tensor = self.get_tensor(self.get_node(input_node), node)
stack_inp_tensor_wrapper = self.define_tensor(
input_node,
node,
diff --git a/backends/qualcomm/builders/op_sub.py b/backends/qualcomm/builders/op_sub.py
index 954ca9d3917..e7e5b22bb96 100644
--- a/backends/qualcomm/builders/op_sub.py
+++ b/backends/qualcomm/builders/op_sub.py
@@ -37,7 +37,7 @@ def define_node(
sub_input_tensors = []
for index in range(2):
- input_node = node.args[index]
+ input_node = self.get_node(node.args[index])
input_tensor = self.get_tensor(input_node, node)
tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
diff --git a/backends/qualcomm/builders/op_sum_int_list.py b/backends/qualcomm/builders/op_sum_int_list.py
index 74181f46cb3..fc5546f9d33 100644
--- a/backends/qualcomm/builders/op_sum_int_list.py
+++ b/backends/qualcomm/builders/op_sum_int_list.py
@@ -28,7 +28,7 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_tanh.py b/backends/qualcomm/builders/op_tanh.py
index ddc9fd2a2a6..c06f44b312f 100644
--- a/backends/qualcomm/builders/op_tanh.py
+++ b/backends/qualcomm/builders/op_tanh.py
@@ -26,7 +26,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py
index 5fb016aef95..dc1062846ed 100644
--- a/backends/qualcomm/builders/op_to.py
+++ b/backends/qualcomm/builders/op_to.py
@@ -80,7 +80,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_topk.py b/backends/qualcomm/builders/op_topk.py
index 745cf7b9935..2b5d23268b9 100644
--- a/backends/qualcomm/builders/op_topk.py
+++ b/backends/qualcomm/builders/op_topk.py
@@ -33,7 +33,7 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_transpose.py b/backends/qualcomm/builders/op_transpose.py
index d29fc73084c..7fb02a2fb7c 100644
--- a/backends/qualcomm/builders/op_transpose.py
+++ b/backends/qualcomm/builders/op_transpose.py
@@ -28,7 +28,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
permute_node = input_node if QCOM_INSERTED_PERMUTE in node.meta else node
input_tensor = self.get_tensor(input_node, permute_node)
input_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_unbind.py b/backends/qualcomm/builders/op_unbind.py
index 8ca62e2a07b..1c505e6f4fd 100644
--- a/backends/qualcomm/builders/op_unbind.py
+++ b/backends/qualcomm/builders/op_unbind.py
@@ -27,7 +27,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_unsqueeze.py b/backends/qualcomm/builders/op_unsqueeze.py
index 55790129462..f5cd7af3b2e 100644
--- a/backends/qualcomm/builders/op_unsqueeze.py
+++ b/backends/qualcomm/builders/op_unsqueeze.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_upsample_bilinear2d.py b/backends/qualcomm/builders/op_upsample_bilinear2d.py
index 654fb934571..10dfe375fe0 100644
--- a/backends/qualcomm/builders/op_upsample_bilinear2d.py
+++ b/backends/qualcomm/builders/op_upsample_bilinear2d.py
@@ -26,7 +26,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_upsample_nearest2d.py b/backends/qualcomm/builders/op_upsample_nearest2d.py
index c4b353fd3e9..4e9c4741ca2 100644
--- a/backends/qualcomm/builders/op_upsample_nearest2d.py
+++ b/backends/qualcomm/builders/op_upsample_nearest2d.py
@@ -26,7 +26,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- input_node = node.args[0]
+ input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
diff --git a/backends/qualcomm/builders/op_where.py b/backends/qualcomm/builders/op_where.py
index ecac45a7a6f..94ee1b0e940 100644
--- a/backends/qualcomm/builders/op_where.py
+++ b/backends/qualcomm/builders/op_where.py
@@ -25,7 +25,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
- conditional_input_node = node.args[0]
+ conditional_input_node = self.get_node(node.args[0])
conditional_input_tensor = self.get_tensor(conditional_input_node, node)
conditional_input_tensor_wrapper = self.define_tensor(
conditional_input_node,
@@ -35,7 +35,7 @@ def define_node(
nodes_to_wrappers,
)
- true_input_node = node.args[1]
+ true_input_node = self.get_node(node.args[1])
true_input_tensor = self.get_tensor(true_input_node, node)
true_input_tensor_wrapper = self.define_tensor(
true_input_node,
@@ -45,7 +45,7 @@ def define_node(
nodes_to_wrappers,
)
- false_input_node = node.args[2]
+ false_input_node = self.get_node(node.args[2])
false_input_tensor = self.get_tensor(false_input_node, node)
false_input_tensor_wrapper = self.define_tensor(
false_input_node,
diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py
index d9eb188614c..7e5a779e748 100644
--- a/backends/qualcomm/partition/qnn_partitioner.py
+++ b/backends/qualcomm/partition/qnn_partitioner.py
@@ -12,7 +12,10 @@
from executorch.backends.qualcomm.builders import node_visitor
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
-from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER
+from executorch.backends.qualcomm.utils.constants import (
+ QCOM_AXIS_ORDER,
+ QCOM_BYPASS_NODE,
+)
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
@@ -46,7 +49,6 @@ def __init__(
skip_node_op_set: set = None,
):
self.node_visitors = node_visitor.get_node_visitors(edge_program)
-
self.skip_node_op_set = skip_node_op_set
self.skip_node_id_set = skip_node_id_set
self.nodes_to_wrappers = defaultdict(dict)
@@ -70,6 +72,8 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
node.target in allow_list_operator
# bypass if custom op appears
or OpContextLoader.namespace == node.target.namespace
+ # bypass dequantize op for parameters & buffers
+ or node.meta.get(QCOM_BYPASS_NODE, False)
):
return True
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 338209fcd4a..74c85b773c2 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1535,8 +1535,11 @@ def test_qnn_backend_elu(self):
def test_qnn_backend_embedding(self):
module = Embedding() # noqa: F405
sample_input = (torch.Tensor([[1, 2, 4, 5], [4, 3, 2, 9]]).to(torch.int32),)
- module = self.get_qdq_module(module, sample_input)
- self.lower_module_and_test_output(module, sample_input)
+ quant_dtype = [QuantDtype.use_8a8w, QuantDtype.use_16a4w]
+ for i, qdtype in enumerate(quant_dtype):
+ with self.subTest(i=i):
+ qdq_module = self.get_qdq_module(module, sample_input, quant_dtype=qdtype)
+ self.lower_module_and_test_output(qdq_module, sample_input)
def test_qnn_backend_equal(self):
test_comb = [
@@ -2505,7 +2508,7 @@ def test_qnn_backend_profile_op(self):
module,
sample_input,
expected_partitions=1,
- expected_profile_events=34,
+ expected_profile_events=30,
)
def test_qnn_backend_shared_buffer(self):
@@ -2527,6 +2530,9 @@ def test_qnn_backend_shared_buffer(self):
)
def test_qnn_backend_online_prepare(self):
+ if self.enable_x86_64:
+ self.skipTest("online prepare is not supported on host machine")
+
backend_options = generate_htp_compiler_spec(use_fp16=True)
TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
soc_model=self.chipset_table[TestQNN.model],
@@ -3150,7 +3156,7 @@ def test_qnn_backend_profile_op(self):
module,
sample_input,
expected_partitions=1,
- expected_profile_events=35,
+ expected_profile_events=30,
)
def test_qnn_backend_shared_buffer(self):
@@ -3173,6 +3179,9 @@ def test_qnn_backend_shared_buffer(self):
)
def test_qnn_backend_online_prepare(self):
+ if self.enable_x86_64:
+ self.skipTest("online prepare is not supported on host machine")
+
backend_options = generate_htp_compiler_spec(use_fp16=False)
TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
soc_model=self.chipset_table[TestQNN.model],
@@ -3300,17 +3309,17 @@ def test_qnn_backend_dump_context_from_pte(self):
def test_qnn_backend_draw_graph(self):
golden_data = """digraph test {
rankdir=TB
- aten_convolution_default_0 [label=<
+ aten_convolution_default_1_0 [label=<
- | name: aten_convolution_default_0 |
+ | name: aten_convolution_default_1_0 |
| data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8 |
| tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE |
| dims: [1, 28, 28, 32] |
| quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET |
> color=black fillcolor=transparent shape=box style=rounded]
- aten_relu_default_0 [label=<
+ aten_relu_default_1_0 [label=<
- | name: aten_relu_default_0 |
+ | name: aten_relu_default_1_0 |
| data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8 |
| tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE |
| dims: [1, 28, 28, 32] |
@@ -3340,17 +3349,33 @@ def test_qnn_backend_draw_graph(self):
| dims: [32] |
| quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET |
> color=black fillcolor=transparent shape=box style=rounded]
- aten_convolution_default_1_0 [label=<
+ b__frozen_param0_0 [label=<
- | name: aten_convolution_default_1_0 |
+ | name: b__frozen_param0_0 |
+ | data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8 |
+ | tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC |
+ | dims: [3, 3, 32, 32] |
+ | quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET |
+
> color=black fillcolor=transparent shape=box style=rounded]
+ b__frozen_param1_0 [label=<
+
+ | name: b__frozen_param1_0 |
+ | data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32 |
+ | tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC |
+ | dims: [32] |
+ | quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET |
+
> color=black fillcolor=transparent shape=box style=rounded]
+ aten_convolution_default_0 [label=<
+
+ | name: aten_convolution_default_0 |
| data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8 |
| tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE |
| dims: [1, 28, 28, 32] |
| quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET |
> color=black fillcolor=transparent shape=box style=rounded]
- aten_relu_default_1_0 [label=<
+ aten_relu_default_0 [label=<
- | name: aten_relu_default_1_0 |
+ | name: aten_relu_default_0 |
| data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8 |
| tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE |
| dims: [1, 28, 28, 32] |
@@ -3364,14 +3389,6 @@ def test_qnn_backend_draw_graph(self):
| dims: [1, 28, 28, 32] |
| quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET |
> color=black fillcolor=transparent shape=box style=rounded]
- output_quantized_decomposed_dequantize_per_tensor_tensor_0 [label=<
-
- | name: output_quantized_decomposed_dequantize_per_tensor_tensor_0 |
- | data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32 |
- | tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ |
- | dims: [1, 32, 28, 28] |
- | quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED |
-
> color=black fillcolor=transparent shape=box style=rounded]
input_0_x_0 [label=<
| name: input_0_x_0 |
@@ -3380,36 +3397,27 @@ def test_qnn_backend_draw_graph(self):
| dims: [1, 32, 28, 28] |
| quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED |
> color=black fillcolor=transparent shape=box style=rounded]
- b__frozen_param0_0 [label=<
-
- | name: b__frozen_param0_0 |
- | data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8 |
- | tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC |
- | dims: [3, 3, 32, 32] |
- | quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET |
-
> color=black fillcolor=transparent shape=box style=rounded]
- b__frozen_param1_0 [label=<
+ output_quantized_decomposed_dequantize_per_tensor_tensor_0 [label=<
- | name: b__frozen_param1_0 |
- | data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32 |
- | tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC |
- | dims: [32] |
- | quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET |
+ | name: output_quantized_decomposed_dequantize_per_tensor_tensor_0 |
+ | data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32 |
+ | tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ |
+ | dims: [1, 32, 28, 28] |
+ | quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED |
> color=black fillcolor=transparent shape=box style=rounded]
- quantized_decomposed_quantize_per_tensor_default_0 -> aten_convolution_default_0
- input_0_x_0 -> quantized_decomposed_quantize_per_tensor_default_0
- b__frozen_param0_0 -> aten_convolution_default_0
- b__frozen_param1_0 -> aten_convolution_default_0
- aten_convolution_default_0 -> aten_relu_default_0
quantized_decomposed_quantize_per_tensor_default_0 -> aten_convolution_default_1_0
+ input_0_x_0 -> quantized_decomposed_quantize_per_tensor_default_0
b__frozen_param2_0 -> aten_convolution_default_1_0
b__frozen_param3_0 -> aten_convolution_default_1_0
aten_convolution_default_1_0 -> aten_relu_default_1_0
+ quantized_decomposed_quantize_per_tensor_default_0 -> aten_convolution_default_0
+ b__frozen_param0_0 -> aten_convolution_default_0
+ b__frozen_param1_0 -> aten_convolution_default_0
+ aten_convolution_default_0 -> aten_relu_default_0
aten_relu_default_0 -> aten_add_tensor_0
aten_relu_default_1_0 -> aten_add_tensor_0
aten_add_tensor_0 -> output_quantized_decomposed_dequantize_per_tensor_tensor_0
- }
- """
+ }"""
module = DrawGraphModel() # noqa: F405
sample_input = (torch.randn(1, 32, 28, 28),)
module = self.get_qdq_module(module, sample_input)
diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index 695c846de05..0b34290f4c2 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -15,6 +15,7 @@
import numpy as np
import torch
from executorch import exir
+from executorch.backends.qualcomm._passes.utils import dq_ops
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
@@ -298,7 +299,7 @@ def validate_profile():
target_time_scale=TimeScale.CYCLES,
)
self.assertTrue(
- len(inspector.to_dataframe().index) == expected_profile_events
+ len(inspector.to_dataframe().index) >= expected_profile_events
)
def validate_intermediate_tensor():
@@ -583,7 +584,7 @@ def get_converted_sgd_trained_module(
optimizer.zero_grad()
loss.backward()
optimizer.step()
- return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared)
+ return convert_pt2e(prepared)
def split_graph(self, division: int):
class SplitGraph(ExportPass):
@@ -595,6 +596,10 @@ def __init__(self, division):
super().__init__()
self.division = division
+ def _is_legit_node(self, node):
+ # skip dq_ops for frozen_params
+ return node.op == "call_function" and node.target not in dq_ops
+
def _insert_clone(
self, graph_module: torch.fx.GraphModule
) -> torch.fx.GraphModule:
@@ -609,9 +614,11 @@ def _insert_clone(
# Insert clone op to split model based on the shares
num_graph_nodes = 0
for node in graph_module.graph.nodes:
- num_graph_nodes += 1 if node.op == "call_function" else 0
+ if not self._is_legit_node(node):
+ continue
- if num_graph_nodes % shares != 0 or node.op != "call_function":
+ num_graph_nodes += 1
+ if num_graph_nodes % shares != 0:
continue
with graph_module.graph.inserting_after(node):
diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py
index ce917bf4115..a4a087287a4 100644
--- a/backends/qualcomm/utils/constants.py
+++ b/backends/qualcomm/utils/constants.py
@@ -15,6 +15,7 @@
QCOM_BLOCK_SCALE_BITWIDTH = "block_scale_bitwidth"
QCOM_BLOCK_SCALE_OFFSET = "block_scale_offset"
QCOM_BLOCK_STORAGE_TYPE = "block_storage_type"
+QCOM_BYPASS_NODE = "bypass_node"
QCOM_DATA = "data"
QCOM_DTYPE = "dtype"
QCOM_ENCODING = "encoding"
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index 3653cd3176f..e80b03a32c0 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -640,7 +640,8 @@ def prepare_subgm(subgm, subgm_name):
for node in graph_module.graph.nodes:
if node.op == "call_module":
graph_module.set_submodule(
- node.name, convert_pt2e(graph_module.get_submodule(node.name))
+ node.name,
+ convert_pt2e(graph_module.get_submodule(node.name)),
)
# canonicalize graph for lowering again
graph_module, edge_prog_mgrs = _canonicalize_graph_with_lowered_module(
diff --git a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt
index 4d4f1c2e39d..b9b1ddc0f72 100644
--- a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt
+++ b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt
@@ -52,6 +52,7 @@ target_link_libraries(
qnn_executorch_backend
executorch_core
extension_data_loader
+ extension_flat_tensor
extension_module
extension_tensor
gflags
diff --git a/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt b/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt
index 16d91013349..4e44a1599b1 100644
--- a/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt
+++ b/examples/qualcomm/qaihub_scripts/llama/CMakeLists.txt
@@ -39,6 +39,7 @@ target_link_libraries(
qnn_executorch_backend
executorch_core
extension_data_loader
+ extension_flat_tensor
extension_module
extension_tensor
gflags
@@ -87,6 +88,7 @@ target_link_libraries(
qnn_executorch_backend
executorch_core
extension_data_loader
+ extension_flat_tensor
extension_module
extension_tensor
gflags
diff --git a/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt b/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt
index ff22f08cd09..5b63a6678fc 100644
--- a/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt
+++ b/examples/qualcomm/qaihub_scripts/stable_diffusion/CMakeLists.txt
@@ -23,6 +23,7 @@ target_link_libraries(
qnn_executorch_backend
executorch_core
extension_data_loader
+ extension_flat_tensor
extension_module
extension_tensor
gflags
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index d8dab88e998..b5ac77c3d1f 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -251,8 +251,8 @@ def qat_train(ori_model, captured_model, quantizer, dataset):
loss.backward()
optimizer.step()
- return torch.ao.quantization.quantize_pt2e.convert_pt2e(
- torch.ao.quantization.move_exported_model_to_eval(annotated_model)
+ return convert_pt2e(
+ torch.ao.quantization.move_exported_model_to_eval(annotated_model),
)
diff --git a/exir/program/_program.py b/exir/program/_program.py
index f24807e253d..155591dd1af 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -1283,7 +1283,7 @@ def to_edge_transform_and_lower(
@experimental(
"""
- This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed.
+ This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed.
This function will be combined with to_edge in the future.
"""
)