diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 13e426ccd7a0..a6d536eaccf7 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -25,7 +25,7 @@ from tvm.relay.frontend.common import infer_shape from .common import logger -from .pytorch_utils import is_version_greater_than, getattr_attr_name +from .pytorch_utils import is_version_greater_than class QNNParam(object): @@ -540,18 +540,11 @@ def inline_input_quant_params_for_fx(graph, params, param_debug_name_map): # pylint: disable=c-extension-no-member import torch - def get_full_attr_name(current): - current_attr = getattr_attr_name(current) - inputs = list(current.inputs()) - if len(inputs) == 1 and inputs[0].node().kind() == "prim::GetAttr": - return get_full_attr_name(inputs[0].node()) + "." + current_attr - return current_attr - for node in graph.findAllNodes("prim::GetAttr", recurse=True): out_name = node.output().debugName() if "_scale" in out_name or "_zero_point" in out_name: - full_attr = param_debug_name_map[get_full_attr_name(node)] + full_attr = param_debug_name_map[out_name] assert full_attr in params, f"{full_attr} not found in param dict." param_np = params[full_attr].numpy() new_const_node = graph.create("prim::Constant")