From 0952a714cfa052b6e5358e22c2b0fd6153137f85 Mon Sep 17 00:00:00 2001 From: PineApple777 Date: Fri, 8 Dec 2023 19:13:22 +0900 Subject: [PATCH 1/2] update qnn_torch.py --- python/tvm/relay/frontend/qnn_torch.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 13e426ccd7a0..676784053592 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -540,18 +540,11 @@ def inline_input_quant_params_for_fx(graph, params, param_debug_name_map): # pylint: disable=c-extension-no-member import torch - def get_full_attr_name(current): - current_attr = getattr_attr_name(current) - inputs = list(current.inputs()) - if len(inputs) == 1 and inputs[0].node().kind() == "prim::GetAttr": - return get_full_attr_name(inputs[0].node()) + "." + current_attr - return current_attr - for node in graph.findAllNodes("prim::GetAttr", recurse=True): out_name = node.output().debugName() if "_scale" in out_name or "_zero_point" in out_name: - full_attr = param_debug_name_map[get_full_attr_name(node)] + full_attr = param_debug_name_map[out_name] assert full_attr in params, f"{full_attr} not found in param dict." param_np = params[full_attr].numpy() new_const_node = graph.create("prim::Constant") From 53aee6b4088d73a453c1cc4e4ccddb6a03128dec Mon Sep 17 00:00:00 2001 From: PineApple777 Date: Sat, 23 Dec 2023 18:30:14 +0900 Subject: [PATCH 2/2] remove unused function --- python/tvm/relay/frontend/qnn_torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 676784053592..a6d536eaccf7 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -25,7 +25,7 @@ from tvm.relay.frontend.common import infer_shape from .common import logger -from .pytorch_utils import is_version_greater_than, getattr_attr_name +from .pytorch_utils import is_version_greater_than class QNNParam(object):