From 8755e1f8f32aae6fc165bfecc26be3ce16d81f16 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Tue, 14 May 2024 19:56:43 -0700 Subject: [PATCH 1/5] Add OpenAI CLIP fusions --- .../transformers/fusion_attention_clip.py | 81 +++++++++++++----- .../tools/transformers/fusion_layernorm.py | 34 +++++--- .../tools/transformers/fusion_quickgelu.py | 84 +++++++++++++++++++ .../tools/transformers/onnx_model_bert.py | 5 ++ .../tools/transformers/onnx_model_clip.py | 1 + 5 files changed, 174 insertions(+), 31 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/fusion_quickgelu.py diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index d400e248d6cca..c30671e23692c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -97,7 +97,33 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): else: # Deal with the first attention after the embedding layer. for i in [0, 1]: - node_before_layer_norm = self.model.match_parent(normalize_node, "Add", i) + node_before_layer_norm = None + + node_before_layer_norm_1 = self.model.match_parent(normalize_node, "Add", i) + node_before_layer_norm_2 = self.model.match_parent(normalize_node, "LayerNormalization", i) + if node_before_layer_norm_1 is not None: + # Add -----------+ + # | | + # LayerNorm | + # | | + # LayerNorm | + # | | + # Attention subgraph | + # | | + # SkipLayerNorm ------+ + node_before_layer_norm = node_before_layer_norm_1 + elif node_before_layer_norm_2 is not None: + # Add + # | + # LayerNorm --------+ + # | | + # LayerNorm | + # | | + # Attention subgraph | + # | | + # SkipLayerNorm ------+ + node_before_layer_norm = node_before_layer_norm_2 + if node_before_layer_norm is None: continue child = self.model.find_first_child_by_type( @@ -130,20 +156,32 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return (_, _, reshape_v, add_v, matmul_v) = v_nodes + add_mask = None add_mask_indices = [] - qk_nodes = self.model.match_parent_path( + qk_nodes = None + qk_nodes_1 = self.model.match_parent_path( matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, None, 0], return_indice=add_mask_indices, ) - if qk_nodes is None: + qk_nodes_2 = self.model.match_parent_path( + matmul_qkv, + ["Softmax", "MatMul"], + [0, 0], + ) + if qk_nodes_1 is not None: + qk_nodes = qk_nodes_1 + assert len(add_mask_indices) == 1 + causal_mask_input_index = 1 - add_mask_indices[0] + + (_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes + elif qk_nodes_2 is not None: + qk_nodes = qk_nodes_2 + (_softmax_qk, matmul_qk) = qk_nodes + else: logger.debug("fuse_attention: failed to match qk path") return - assert len(add_mask_indices) == 1 - causal_mask_input_index = 1 - add_mask_indices[0] - - (_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes q_nodes = self.model.match_parent_path( matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None] @@ -172,23 +210,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): attention_last_node = reshape_qkv - # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path - # of computing causal mask. - causal_mask_nodes = self.model.match_parent_path( - add_mask, - ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0, 0], - ) - if causal_mask_nodes is None: - # If the model is exported with batch_size == 1, there is no Concat node + if add_mask is not None: + # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path + # of computing causal mask. causal_mask_nodes = self.model.match_parent_path( add_mask, - ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], - [causal_mask_input_index, 0, 0, 0, 0], + ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0, 0], ) if causal_mask_nodes is None: - logger.debug("fuse_attention: failed to match causal mask subgraph") - return + # If the model is exported with batch_size == 1, there is no Concat node + causal_mask_nodes = self.model.match_parent_path( + add_mask, + ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"], + [causal_mask_input_index, 0, 0, 0, 0], + ) + if causal_mask_nodes is None: + logger.debug("fuse_attention: failed to match causal mask subgraph") + return new_node = self.create_attention_node( mask_index=None, @@ -204,7 +243,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): output=attention_last_node.output[0], add_qk_str=None, scale=None, - causal=True, + causal=(add_mask is not None), ) if new_node is None: return diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index 68d26fc46fa23..a86d6d589e222 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -38,6 +38,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): | | +----------------------+ """ + subgraph_nodes = [] children = self.model.get_children(node, input_name_to_nodes) if len(children) == 0 or len(children) > 2: return @@ -53,9 +54,17 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): div_node = None for child in children: - div_node = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False) - if div_node is not None: - break + # Check if Sub --> Div exists + div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False) + + # Check if Sub --> Cast --> Div + div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[]) + + if div_node_1 is not None: + div_node = div_node_1 + elif div_node_2 != []: + div_node = div_node_2[-1] + # subgraph_nodes.append(div_node_2[0]) # add Cast node to list of subgraph nodes if div_node is None: return @@ -63,10 +72,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): div_node, [ (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]), - ( - ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], - [1, 0, 0, 0, 0, 0], - ), + (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]), ], output_name_to_node, ) @@ -87,7 +93,14 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): if self.model.find_constant_input(pow_node, 2.0) != 1: return - mul_node = input_name_to_nodes[div_node.output[0]][0] + temp_node = input_name_to_nodes[div_node.output[0]][0] + if temp_node.op_type == "Cast": + # Div --> Cast --> Mul + subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes + mul_node = input_name_to_nodes[temp_node.output[0]][0] + else: + # Div --> Mul + mul_node = temp_node if mul_node.op_type != "Mul": return @@ -95,7 +108,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): if last_add_node.op_type != "Add": return - subgraph_nodes = [node] + subgraph_nodes.append(node) subgraph_nodes.extend(children) subgraph_nodes.extend(parent_nodes[:-1]) @@ -109,7 +122,8 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): logger.debug("It is not safe to fuse LayerNormalization node. Skip") return - weight_input = mul_node.input[1 - self.model.input_index(div_node.output[0], mul_node)] + node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node + weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)] if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"): return diff --git a/onnxruntime/python/tools/transformers/fusion_quickgelu.py b/onnxruntime/python/tools/transformers/fusion_quickgelu.py new file mode 100644 index 0000000000000..86caf7301d4b9 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_quickgelu.py @@ -0,0 +1,84 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from fusion_base import Fusion +from onnx import helper +from onnx_model import OnnxModel + +import logging + +logger = logging.getLogger(__name__) + + +class FusionQuickGelu(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "QuickGelu", ["MatMul"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + # Fuse the following subgraph to `QuickGelu` + # + # root_input + # / \ + # | Mul ----+ + # | (B = ~1.702) | + # \ | | + # \ Sigmoid |---- `QuickGelu` + # \ / | + # \ / | + # Mul ----+ + # | + # MatMul [node] + + second_mul_node = self.model.match_parent_path(node, ["Mul"], [0]) + if second_mul_node is None: + logger.debug("fuse_quickgelu: failed to match second Mul node") + return + second_mul_node = second_mul_node[0] + + root_input = None + root_input_1 = self.model.match_parent_path(second_mul_node, ["Add"], [0]) + root_input_2 = self.model.match_parent_path(second_mul_node, ["MatMul"], [0]) + if root_input_1 is not None: + root_input = root_input_1[0].output[0] + elif root_input_2 is not None: + root_input = root_input_2[0].output[0] + else: + logger.debug("fuse_quickgelu: failed to match root input") + return + + sigmoid_node = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1]) + if sigmoid_node is None: + logger.debug("fuse_quickgelu: failed to match Sigmoid node") + return + sigmoid_node = sigmoid_node[0] + + first_mul_node = self.model.match_parent_path(sigmoid_node, ["Mul"], [0]) + if first_mul_node is None: + logger.debug("fuse_quickgelu: failed to match first Mul node") + return + first_mul_node = first_mul_node[0] + + approximation_value = self.model.get_constant_value(first_mul_node.input[1]).item() + if approximation_value != 1.7021484375: + logger.debug("fuse_quickgelu: failed to match approximation value") + return + + if first_mul_node.input[0] != root_input: + logger.debug("fuse_quickgelu: failed to match root input with first Mul node's input") + return + + new_node = helper.make_node( + "QuickGelu", + inputs=[root_input], + outputs=[second_mul_node.output[0]], + name=self.model.create_node_name("QuickGelu"), + ) + new_node.domain = "com.microsoft" + new_node.attribute.extend([helper.make_attribute("alpha", approximation_value)]) + + self.nodes_to_remove.extend([first_mul_node, sigmoid_node, second_mul_node]) + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + self.increase_counter("QuickGelu") diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 431e64509e3cc..5eef65df19c8c 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -21,6 +21,7 @@ from fusion_qordered_gelu import FusionQOrderedGelu from fusion_qordered_layernorm import FusionQOrderedLayerNormalization from fusion_qordered_matmul import FusionQOrderedMatMul +from fusion_quickgelu import FusionQuickGelu from fusion_reshape import FusionReshape from fusion_rotary_attention import FusionRotaryEmbeddings from fusion_shape import FusionShape @@ -65,6 +66,8 @@ def fuse_gelu(self): fusion.apply() fusion = FusionFastGelu(self) fusion.apply() + fusion = FusionQuickGelu(self) + fusion.apply() # Only relevant in models with Q-DQ nodes fusion = FusionQOrderedGelu(self) fusion.apply() @@ -347,6 +350,8 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if (options is None) or options.enable_rotary_embeddings: self.fuse_rotary_embeddings() + # OnnxModel.save(self.model, "temp.onnx", save_as_external_data=True, all_tensors_to_one_file=True) + if options is not None: self.attention_mask.set_mask_format(options.attention_mask_format) if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention): diff --git a/onnxruntime/python/tools/transformers/onnx_model_clip.py b/onnxruntime/python/tools/transformers/onnx_model_clip.py index 9b4ca03a47a5b..32bddc3ca16a0 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_clip.py +++ b/onnxruntime/python/tools/transformers/onnx_model_clip.py @@ -25,6 +25,7 @@ def get_fused_operator_statistics(self): ops = [ "Attention", "LayerNormalization", + "QuickGelu", "SkipLayerNormalization", ] for op in ops: From 2135cc88f298cbed98b738c655e79d68ff04b915 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Wed, 15 May 2024 12:35:20 -0700 Subject: [PATCH 2/5] Add QuickGelu fusion test --- .../tools/transformers/fusion_quickgelu.py | 26 ++++++------------- .../python/transformers/test_gelu_fusions.py | 6 +++++ 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_quickgelu.py b/onnxruntime/python/tools/transformers/fusion_quickgelu.py index 86caf7301d4b9..18c95b20517c5 100644 --- a/onnxruntime/python/tools/transformers/fusion_quickgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_quickgelu.py @@ -14,7 +14,7 @@ class FusionQuickGelu(Fusion): def __init__(self, model: OnnxModel): - super().__init__(model, "QuickGelu", ["MatMul"]) + super().__init__(model, "QuickGelu", ["Mul"]) def fuse(self, node, input_name_to_nodes, output_name_to_node): # Fuse the following subgraph to `QuickGelu` @@ -29,25 +29,15 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): # \ / | # Mul ----+ # | - # MatMul [node] - - second_mul_node = self.model.match_parent_path(node, ["Mul"], [0]) - if second_mul_node is None: - logger.debug("fuse_quickgelu: failed to match second Mul node") - return - second_mul_node = second_mul_node[0] + # root_output - root_input = None - root_input_1 = self.model.match_parent_path(second_mul_node, ["Add"], [0]) - root_input_2 = self.model.match_parent_path(second_mul_node, ["MatMul"], [0]) - if root_input_1 is not None: - root_input = root_input_1[0].output[0] - elif root_input_2 is not None: - root_input = root_input_2[0].output[0] - else: - logger.debug("fuse_quickgelu: failed to match root input") + if node.op_type != "Mul": + logger.debug("fuse_quickgelu: failed to match second Mul node") return + second_mul_node = node + root_input = second_mul_node.input[0] + sigmoid_node = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1]) if sigmoid_node is None: logger.debug("fuse_quickgelu: failed to match Sigmoid node") @@ -61,7 +51,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): first_mul_node = first_mul_node[0] approximation_value = self.model.get_constant_value(first_mul_node.input[1]).item() - if approximation_value != 1.7021484375: + if abs(approximation_value - 1.7021484375) >= 1e-3: logger.debug("fuse_quickgelu: failed to match approximation value") return diff --git a/onnxruntime/test/python/transformers/test_gelu_fusions.py b/onnxruntime/test/python/transformers/test_gelu_fusions.py index 77a6491d4bd3c..f6c6348ae8c17 100644 --- a/onnxruntime/test/python/transformers/test_gelu_fusions.py +++ b/onnxruntime/test/python/transformers/test_gelu_fusions.py @@ -21,6 +21,11 @@ def forward(self, x): return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) +class HuggingfaceQuickGelu(torch.nn.Module): + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + class MegatronGelu(torch.nn.Module): def forward(self, x): # The original implementation using ones_like, which might cause problem for input with dynamic axes in onnx. @@ -36,6 +41,7 @@ def forward(self, x): test_cases = [ ("huggingface", "Gelu", HuggingfaceGelu), ("huggingface", "FastGelu", HuggingfaceFastGelu), + ("huggingface", "QuickGelu", HuggingfaceQuickGelu), ("megatron", "Gelu", MegatronGelu), ("megatron", "FastGelu", MegatronFastGelu), ] From c50a4aa402c6e5ada85b3b08b518baee8aa3f83a Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 18 May 2024 05:39:31 +0000 Subject: [PATCH 3/5] Remove commented out code --- onnxruntime/python/tools/transformers/fusion_layernorm.py | 1 - onnxruntime/python/tools/transformers/onnx_model_bert.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index a86d6d589e222..4758f21b2292e 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -64,7 +64,6 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): div_node = div_node_1 elif div_node_2 != []: div_node = div_node_2[-1] - # subgraph_nodes.append(div_node_2[0]) # add Cast node to list of subgraph nodes if div_node is None: return diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 5eef65df19c8c..ad51c1cce0ec4 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -350,8 +350,6 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if (options is None) or options.enable_rotary_embeddings: self.fuse_rotary_embeddings() - # OnnxModel.save(self.model, "temp.onnx", save_as_external_data=True, all_tensors_to_one_file=True) - if options is not None: self.attention_mask.set_mask_format(options.attention_mask_format) if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention): From 433d3fc3593c4cbec0bfcb1ec00f1736fa861fcf Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 18 May 2024 05:41:28 +0000 Subject: [PATCH 4/5] Add changes suggested by linter --- .../python/tools/transformers/fusion_attention_clip.py | 2 +- onnxruntime/python/tools/transformers/fusion_quickgelu.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index c30671e23692c..b027957fcc725 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -98,7 +98,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Deal with the first attention after the embedding layer. for i in [0, 1]: node_before_layer_norm = None - + node_before_layer_norm_1 = self.model.match_parent(normalize_node, "Add", i) node_before_layer_norm_2 = self.model.match_parent(normalize_node, "LayerNormalization", i) if node_before_layer_norm_1 is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_quickgelu.py b/onnxruntime/python/tools/transformers/fusion_quickgelu.py index 18c95b20517c5..87154a1c421a8 100644 --- a/onnxruntime/python/tools/transformers/fusion_quickgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_quickgelu.py @@ -3,12 +3,12 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +import logging + from fusion_base import Fusion from onnx import helper from onnx_model import OnnxModel -import logging - logger = logging.getLogger(__name__) @@ -30,7 +30,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): # Mul ----+ # | # root_output - + if node.op_type != "Mul": logger.debug("fuse_quickgelu: failed to match second Mul node") return @@ -43,7 +43,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_quickgelu: failed to match Sigmoid node") return sigmoid_node = sigmoid_node[0] - + first_mul_node = self.model.match_parent_path(sigmoid_node, ["Mul"], [0]) if first_mul_node is None: logger.debug("fuse_quickgelu: failed to match first Mul node") From 9bc5af31be7c808fcbade5d23d4b4fa3ecddd2ef Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 18 May 2024 07:08:54 +0000 Subject: [PATCH 5/5] Fix path mismatch check --- onnxruntime/python/tools/transformers/fusion_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index 4758f21b2292e..678d8c42bad67 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -62,7 +62,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): if div_node_1 is not None: div_node = div_node_1 - elif div_node_2 != []: + elif div_node_2 is not None: div_node = div_node_2[-1] if div_node is None: return