diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index 433c10e504..e2987cfc5e 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -37,13 +37,11 @@ def __init__( self, name, *, - double_transpose: bool, is_rotary: bool, has_past_present: bool, is_cross_attention: bool, ): super().__init__(name) - self._double_transpose = double_transpose self._is_rotary = is_rotary self._has_past_present = has_past_present self._is_cross_attention = is_cross_attention @@ -345,12 +343,10 @@ def rewrite( def _make_rule_set(has_past_present: bool): parameter_combinations = [ { - "double_transpose": double_transpose, "is_rotary": is_rotary, "has_past_present": has_past_present, "is_cross_attention": is_cross_attention, } - for double_transpose in [False, True] for is_rotary in [False, True] for is_cross_attention in ([False] if has_past_present else [False, True]) ] @@ -360,7 +356,6 @@ def _make_rule_set(has_past_present: bool): [ MultiHeadAttention.rule( f"MHA" - f"{'_Twice' if params['double_transpose'] else ''}" f"{'_Rotary' if params['is_rotary'] else ''}" f"{'_Past' if params['has_past_present'] else ''}" f"{'_CrossAttention' if params['is_cross_attention'] else ''}",