Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 60 additions & 21 deletions onnxruntime/python/tools/transformers/fusion_attention_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,33 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
else:
# Deal with the first attention after the embedding layer.
for i in [0, 1]:
node_before_layer_norm = self.model.match_parent(normalize_node, "Add", i)
node_before_layer_norm = None

node_before_layer_norm_1 = self.model.match_parent(normalize_node, "Add", i)
node_before_layer_norm_2 = self.model.match_parent(normalize_node, "LayerNormalization", i)
if node_before_layer_norm_1 is not None:
# Add -----------+
# | |
# LayerNorm |
# | |
# LayerNorm |
# | |
# Attention subgraph |
# | |
# SkipLayerNorm ------+
node_before_layer_norm = node_before_layer_norm_1
elif node_before_layer_norm_2 is not None:
# Add
# |
# LayerNorm --------+
# | |
# LayerNorm |
# | |
# Attention subgraph |
# | |
# SkipLayerNorm ------+
node_before_layer_norm = node_before_layer_norm_2

if node_before_layer_norm is None:
continue
child = self.model.find_first_child_by_type(
Expand Down Expand Up @@ -130,20 +156,32 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
return
(_, _, reshape_v, add_v, matmul_v) = v_nodes

add_mask = None
add_mask_indices = []
qk_nodes = self.model.match_parent_path(
qk_nodes = None
qk_nodes_1 = self.model.match_parent_path(
matmul_qkv,
["Softmax", "Reshape", "Add", "Reshape", "MatMul"],
[0, 0, 0, None, 0],
return_indice=add_mask_indices,
)
if qk_nodes is None:
qk_nodes_2 = self.model.match_parent_path(
matmul_qkv,
["Softmax", "MatMul"],
[0, 0],
)
if qk_nodes_1 is not None:
qk_nodes = qk_nodes_1
assert len(add_mask_indices) == 1
causal_mask_input_index = 1 - add_mask_indices[0]

(_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes
elif qk_nodes_2 is not None:
qk_nodes = qk_nodes_2
(_softmax_qk, matmul_qk) = qk_nodes
else:
logger.debug("fuse_attention: failed to match qk path")
return
assert len(add_mask_indices) == 1
causal_mask_input_index = 1 - add_mask_indices[0]

(_softmax_qk, _, add_mask, _, matmul_qk) = qk_nodes

q_nodes = self.model.match_parent_path(
matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None]
Expand Down Expand Up @@ -172,23 +210,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):

attention_last_node = reshape_qkv

# Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
# of computing causal mask.
causal_mask_nodes = self.model.match_parent_path(
add_mask,
["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0, 0],
)
if causal_mask_nodes is None:
# If the model is exported with batch_size == 1, there is no Concat node
if add_mask is not None:
# Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
# of computing causal mask.
causal_mask_nodes = self.model.match_parent_path(
add_mask,
["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0],
["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0, 0],

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable

Local variable 'causal_mask_input_index' may be used before it is initialized.
)
if causal_mask_nodes is None:
logger.debug("fuse_attention: failed to match causal mask subgraph")
return
# If the model is exported with batch_size == 1, there is no Concat node
causal_mask_nodes = self.model.match_parent_path(
add_mask,
["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
[causal_mask_input_index, 0, 0, 0, 0],
)
if causal_mask_nodes is None:
logger.debug("fuse_attention: failed to match causal mask subgraph")
return

new_node = self.create_attention_node(
mask_index=None,
Expand All @@ -204,7 +243,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
output=attention_last_node.output[0],
add_qk_str=None,
scale=None,
causal=True,
causal=(add_mask is not None),
)
if new_node is None:
return
Expand Down
33 changes: 23 additions & 10 deletions onnxruntime/python/tools/transformers/fusion_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
| |
+----------------------+
"""
subgraph_nodes = []
children = self.model.get_children(node, input_name_to_nodes)
if len(children) == 0 or len(children) > 2:
return
Expand All @@ -53,20 +54,24 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):

div_node = None
for child in children:
div_node = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
if div_node is not None:
break
# Check if Sub --> Div exists
div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)

# Check if Sub --> Cast --> Div
div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[])

if div_node_1 is not None:
div_node = div_node_1
elif div_node_2 is not None:
div_node = div_node_2[-1]
if div_node is None:
return

path_id, parent_nodes, _ = self.model.match_parent_paths(
div_node,
[
(["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
(
["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
[1, 0, 0, 0, 0, 0],
),
(["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
],
output_name_to_node,
)
Expand All @@ -87,15 +92,22 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
if self.model.find_constant_input(pow_node, 2.0) != 1:
return

mul_node = input_name_to_nodes[div_node.output[0]][0]
temp_node = input_name_to_nodes[div_node.output[0]][0]
if temp_node.op_type == "Cast":
# Div --> Cast --> Mul
subgraph_nodes.append(temp_node) # add Cast node to list of subgraph nodes
mul_node = input_name_to_nodes[temp_node.output[0]][0]
else:
# Div --> Mul
mul_node = temp_node
if mul_node.op_type != "Mul":
return

last_add_node = input_name_to_nodes[mul_node.output[0]][0]
if last_add_node.op_type != "Add":
return

subgraph_nodes = [node]
subgraph_nodes.append(node)
subgraph_nodes.extend(children)
subgraph_nodes.extend(parent_nodes[:-1])

Expand All @@ -109,7 +121,8 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
logger.debug("It is not safe to fuse LayerNormalization node. Skip")
return

weight_input = mul_node.input[1 - self.model.input_index(div_node.output[0], mul_node)]
node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"):
return

Expand Down
74 changes: 74 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_quickgelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import logging

from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel

logger = logging.getLogger(__name__)


class FusionQuickGelu(Fusion):
    """Fuse the ``x * Sigmoid(alpha * x)`` subgraph into a single ``QuickGelu`` node.

    The fused node is created in the ``com.microsoft`` domain and carries the
    matched multiplier as its ``alpha`` attribute.
    """

    def __init__(self, model: OnnxModel):
        # NOTE(review): the arguments are presumably the fused op name and the
        # op types whose nodes are handed to `fuse` -- confirm against `Fusion`.
        super().__init__(model, "QuickGelu", ["Mul"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the QuickGelu subgraph that ends at ``node``.

        Args:
            node: Candidate node; must be the final ``Mul`` of the pattern.
            input_name_to_nodes: Unused here; part of the ``fuse`` signature.
            output_name_to_node: Unused here; part of the ``fuse`` signature.

        Returns:
            None. On a successful match the three matched nodes are queued for
            removal and the new ``QuickGelu`` node is queued for insertion.
            On any mismatch the method logs at debug level and returns without
            touching the graph.
        """
        # Fuse the following subgraph to `QuickGelu`
        #
        #       root_input
        #       /        \
        #      |         Mul ----+
        #      |   (B = ~1.702)  |
        #       \         |      |
        #        \     Sigmoid   |---- `QuickGelu`
        #         \      /       |
        #          \    /        |
        #           Mul      ----+
        #            |
        #       root_output

        if node.op_type != "Mul":
            logger.debug("fuse_quickgelu: failed to match second Mul node")
            return

        second_mul_node = node
        # Only this operand order is matched: root input at index 0 and the
        # Sigmoid branch at index 1 of the final Mul.
        root_input = second_mul_node.input[0]

        sigmoid_path = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1])
        if sigmoid_path is None:
            logger.debug("fuse_quickgelu: failed to match Sigmoid node")
            return
        sigmoid_node = sigmoid_path[0]

        first_mul_path = self.model.match_parent_path(sigmoid_node, ["Mul"], [0])
        if first_mul_path is None:
            logger.debug("fuse_quickgelu: failed to match first Mul node")
            return
        first_mul_node = first_mul_path[0]

        # `fuse` is invoked for every Mul node, so the multiplier of an
        # unrelated `Mul -> Sigmoid -> Mul` chain may not be a constant at all.
        # Treat that as "no match" instead of crashing on `None.item()`.
        alpha_tensor = self.model.get_constant_value(first_mul_node.input[1])
        if alpha_tensor is None:
            logger.debug("fuse_quickgelu: failed to match approximation value")
            return
        try:
            approximation_value = alpha_tensor.item()
        except ValueError:
            # `.item()` rejects constants holding more than one element; such
            # a multiplier cannot be the scalar QuickGelu alpha.
            logger.debug("fuse_quickgelu: failed to match approximation value")
            return

        # NOTE(review): 1.7021484375 is presumably 1.702 rounded to float16;
        # the tolerance accepts the float32 constant 1.702 as well -- confirm.
        if abs(approximation_value - 1.7021484375) >= 1e-3:
            logger.debug("fuse_quickgelu: failed to match approximation value")
            return

        if first_mul_node.input[0] != root_input:
            logger.debug("fuse_quickgelu: failed to match root input with first Mul node's input")
            return

        new_node = helper.make_node(
            "QuickGelu",
            inputs=[root_input],
            outputs=[second_mul_node.output[0]],
            name=self.model.create_node_name("QuickGelu"),
        )
        new_node.domain = "com.microsoft"
        new_node.attribute.extend([helper.make_attribute("alpha", approximation_value)])

        self.nodes_to_remove.extend([first_mul_node, sigmoid_node, second_mul_node])
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
        self.increase_counter("QuickGelu")
3 changes: 3 additions & 0 deletions onnxruntime/python/tools/transformers/onnx_model_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from fusion_qordered_gelu import FusionQOrderedGelu
from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
from fusion_qordered_matmul import FusionQOrderedMatMul
from fusion_quickgelu import FusionQuickGelu
from fusion_reshape import FusionReshape
from fusion_rotary_attention import FusionRotaryEmbeddings
from fusion_shape import FusionShape
Expand Down Expand Up @@ -65,6 +66,8 @@ def fuse_gelu(self):
fusion.apply()
fusion = FusionFastGelu(self)
fusion.apply()
fusion = FusionQuickGelu(self)
fusion.apply()
# Only relevant in models with Q-DQ nodes
fusion = FusionQOrderedGelu(self)
fusion.apply()
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/python/tools/transformers/onnx_model_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def get_fused_operator_statistics(self):
ops = [
"Attention",
"LayerNormalization",
"QuickGelu",
"SkipLayerNormalization",
]
for op in ops:
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/test/python/transformers/test_gelu_fusions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def forward(self, x):
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))


class HuggingfaceQuickGelu(torch.nn.Module):
    """Module computing the QuickGelu expression ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        # Operand order is kept exactly as written (`1.702 * x`, then
        # `x * gate`) so the sequence of tensor operations is unchanged.
        gate = torch.sigmoid(1.702 * x)
        return x * gate


class MegatronGelu(torch.nn.Module):
def forward(self, x):
# The original implementation using ones_like, which might cause problem for input with dynamic axes in onnx.
Expand All @@ -36,6 +41,7 @@ def forward(self, x):
test_cases = [
("huggingface", "Gelu", HuggingfaceGelu),
("huggingface", "FastGelu", HuggingfaceFastGelu),
("huggingface", "QuickGelu", HuggingfaceQuickGelu),
("megatron", "Gelu", MegatronGelu),
("megatron", "FastGelu", MegatronFastGelu),
]
Expand Down