From 6fa631ddf8f3839817dfd8b913b37f7fa21fe17f Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Mon, 7 Apr 2025 05:46:15 +0000 Subject: [PATCH 1/2] Update K path in conformer attention fusion --- .../python/tools/transformers/fusion_conformer_attention.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py index 0f0c12b0e0200..398b0312c33a0 100644 --- a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py @@ -122,8 +122,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if k_nodes is None: k_nodes = self.model.match_parent_path( matmul_qk, - ["Transpose", "Transpose", "Reshape", "Add", "MatMul"], - [1, 0, 0, 0, 0], + ["Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0], ) if k_nodes is None: logger.debug("fuse_conformer_attention: failed to match k path") @@ -188,7 +188,6 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed") return - self.increase_counter(new_node.op_type) self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name From 80239227ca57855c1326d6336c1fb81f23ab22d7 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Wed, 23 Apr 2025 18:23:25 +0000 Subject: [PATCH 2/2] Keep original path when matching --- .../transformers/fusion_conformer_attention.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py index 398b0312c33a0..2b7fbffa842f7 100644 --- a/onnxruntime/python/tools/transformers/fusion_conformer_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_conformer_attention.py @@ -122,12 +122,18 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if k_nodes is None: k_nodes = self.model.match_parent_path( matmul_qk, - ["Transpose", "Reshape", "Add", "MatMul"], - [1, 0, 0, 0], + ["Transpose", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0, 0], ) if k_nodes is None: - logger.debug("fuse_conformer_attention: failed to match k path") - return + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0], + ) + if k_nodes is None: + logger.debug("fuse_conformer_attention: failed to match k path") + return else: concat_k = k_nodes[1] concat_parent = self.model.get_parent(concat_k, 0, None)