diff --git a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py index 0f0c12b0e0200..2b7fbffa842f7 100644 --- a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py @@ -126,8 +126,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): [1, 0, 0, 0, 0], ) if k_nodes is None: - logger.debug("fuse_conformer_attention: failed to match k path") - return + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0], + ) + if k_nodes is None: + logger.debug("fuse_conformer_attention: failed to match k path") + return else: concat_k = k_nodes[1] concat_parent = self.model.get_parent(concat_k, 0, None) @@ -188,7 +194,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed") return - self.increase_counter(new_node.op_type) self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name