
Update LLaMA attention fusions #19200

Merged
kunal-vaishnavi merged 3 commits into microsoft:main from kunal-vaishnavi:kvaishnavi/llama-fix-attn-mask
Jan 19, 2024

Conversation

@kunal-vaishnavi
Contributor

Description

This PR updates the LLaMA-2 attention fusions by adding the following.

  • Loading the PyTorch model from Hugging Face with the `LlamaAttention` class before exporting
  • Updating the attention mask pattern matching to support another case

This PR also fixes issue #19040.

Motivation and Context

Recent changes to Hugging Face's `transformers` library break the existing pattern matching. Since the attention fusions aim to change the per-layer graph from `LayerNorm Op --> Set of Attention Nodes --> LayerNorm Op` to `LayerNorm Op --> Attention Op --> LayerNorm Op`, it ultimately does not matter which nodes comprise the `Set of Attention Nodes`: they are all removed and replaced by the `Attention Op` in the end.

Therefore, it does not matter whether the `LlamaAttention` class or a different attention class is used to load the PyTorch model before exporting, because the expected graphs after the attention fusions look identical regardless of the attention class chosen. By loading the PyTorch model with the `LlamaAttention` class instead of another attention class (e.g. `LlamaFlashAttention2` or `LlamaSdpaAttention`) and then exporting it to ONNX, the existing pattern matching continues to work.
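The reasoning above can be illustrated with a toy sketch (plain Python, not the actual onnxruntime fusion code; graphs are just lists of op names and the node names in the example are illustrative, not the real exported ops): any run of nodes between two LayerNorm ops collapses into a single Attention op, so two exports with different attention subgraphs fuse to the identical graph.

```python
# Toy illustration of the attention fusion described above.
# Graphs are modeled as flat lists of op names, not real ONNX graphs.

def fuse_attention(graph):
    """Collapse every run of nodes between two LayerNorm ops into one Attention op."""
    fused = []
    i = 0
    while i < len(graph):
        fused.append(graph[i])
        if graph[i] == "LayerNorm":
            # Skip over the attention subgraph up to the next LayerNorm.
            j = i + 1
            while j < len(graph) and graph[j] != "LayerNorm":
                j += 1
            if j < len(graph):
                fused.append("Attention")
                i = j
                continue
        i += 1
    return fused

# Different attention implementations export different subgraphs ...
eager_export = ["LayerNorm", "MatMul", "RotaryEmbedding", "Softmax", "MatMul", "LayerNorm"]
sdpa_export = ["LayerNorm", "ScaledDotProductAttention", "LayerNorm"]

# ... but both fuse to the same final graph.
assert fuse_attention(eager_export) == ["LayerNorm", "Attention", "LayerNorm"]
assert fuse_attention(sdpa_export) == ["LayerNorm", "Attention", "LayerNorm"]
```

The fusion is chosen to pattern-match the `eager` export only because that is the subgraph the existing matching logic recognizes; the fused result is the same either way.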

RyanUnderhill
RyanUnderhill previously approved these changes Jan 19, 2024
@kunal-vaishnavi kunal-vaishnavi merged commit a3ecb63 into microsoft:main Jan 19, 2024
YUNQIUGUO pushed a commit that referenced this pull request Jan 23, 2024
YUNQIUGUO pushed a commit that referenced this pull request Jan 30, 2024
### Description
This PR updates the Whisper export with beam search by adding the
following.

- Fixes a bug when running `DecoderMaskedMultiHeadAttention` in the
Whisper with beam search model
- Sets the default PyTorch attention implementation to `eager` to allow
existing attention fusions to continue working
- Re-uses the cache directory when loading the PyTorch model to reduce
memory used on disk
- Adds `--disable_auto_mixed_precision` to the example FP16 export
command

### Motivation and Context
- [This PR](#19112) added
the `is_unidirectional` parameter to `CheckInputs`, but it was not
provided when checking the inputs in `DecoderMaskedMultiHeadAttention`.
- [This PR](#19200)
explains the reasoning behind why `eager` is used to load the
`WhisperAttention` class.
- By re-using the cache directory for loading the PyTorch model, only
one copy of the PyTorch model is saved on disk instead of two copies.
- By providing this flag, there will be fewer Cast nodes in the Whisper
with beam search model to switch between FP16 and FP32 precision.
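The last point can be sketched with a toy model (plain Python, not the actual converter; op names and precisions are illustrative): auto mixed precision keeps selected ops in FP32 inside an otherwise FP16 graph, and each precision boundary needs a Cast, whereas a uniformly FP16 graph needs none.

```python
# Toy illustration of why disabling auto mixed precision reduces Cast nodes.
# Ops are (name, precision) pairs; a Cast is inserted wherever the precision
# changes between consecutive ops.

def insert_casts(ops):
    """Return the op sequence with a Cast inserted at every precision change."""
    result = []
    prev_precision = None
    for name, precision in ops:
        if prev_precision is not None and precision != prev_precision:
            result.append(("Cast", precision))
        result.append((name, precision))
        prev_precision = precision
    return result

def count_casts(ops):
    return sum(1 for name, _ in ops if name == "Cast")

# Auto mixed precision keeps some ops (e.g. Softmax) in FP32 for accuracy,
# so Casts appear at each FP16 <-> FP32 boundary ...
mixed = [("MatMul", "fp16"), ("Softmax", "fp32"), ("MatMul", "fp16")]
# ... while a uniform FP16 graph has no precision switches at all.
uniform = [("MatMul", "fp16"), ("Softmax", "fp16"), ("MatMul", "fp16")]

assert count_casts(insert_casts(mixed)) == 2
assert count_casts(insert_casts(uniform)) == 0
```

This is only the counting argument, not the converter's actual placement logic; the trade-off is that the uniform FP16 graph gives up the accuracy protection that the FP32 islands provide.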
YUNQIUGUO pushed a commit that referenced this pull request Jan 30, 2024
rohan11235813 pushed a commit to quadric-io/onnxruntime that referenced this pull request Aug 19, 2025
@snnn
Contributor

snnn commented Sep 5, 2025

This PR has been cherry-picked into the rel-1.17.0 branch in PR #19243. Removing the release:1.17.0 label.

rohan11235813 pushed a commit to quadric-io/onnxruntime that referenced this pull request Sep 15, 2025


Successfully merging this pull request may close these issues.

[Documentation] Both new LLama-7B examples are now broken

3 participants