Support Gemma3 with Clip fused attention #24280
Merged
titaiwangms merged 15 commits into main on Apr 4, 2025
Conversation
Contributor
Pull Request Overview
This PR refactors and extends model tracing and fusion tests for the Gemma3 vision model while adding support for fp16 and fp32 tracing patterns and generalizing input indices for op.Add and op.MatMul. Key changes include:
- Introducing a new traced pattern for CLIP attention without an attention mask.
- Generalizing the input indices for op.Add and op.MatMul and differentiating tracing for fp16 and fp32.
- Refactoring test files to support dynamo export and adding new tests for Gemma3 vision attention (SigLip).
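The generalized input indices mentioned above can be illustrated with a minimal, self-contained sketch. This is not onnxruntime's actual `match_parent_path` implementation; the toy graph format and the function below are my own, showing only how `None` can act as a wildcard input index during parent-path matching:

```python
# Sketch only: a dict-based toy graph, not onnxruntime's Graph/NodeProto API.
# graph maps node name -> (op_type, [parent node names, in input order]).

def match_parent_path(graph, node, op_types, input_indices):
    """Walk parents of `node` matching op_types; None index = wildcard slot."""
    path = []
    current = node
    for op_type, idx in zip(op_types, input_indices):
        _, parents = graph[current]
        if idx is not None:  # fixed input slot: only that parent qualifies
            candidates = [parents[idx]] if idx < len(parents) else []
        else:                # wildcard: accept any input producing op_type
            candidates = parents
        match = next((p for p in candidates if graph[p][0] == op_type), None)
        if match is None:
            return None
        path.append(match)
        current = match
    return path

# Toy attention tail: Softmax -> MatMul -> Add, with the MatMul on input 1 of Add.
graph = {
    "out":  ("Add",     ["bias", "mm"]),
    "bias": ("Const",   []),
    "mm":   ("MatMul",  ["sm", "v"]),
    "v":    ("Const",   []),
    "sm":   ("Softmax", []),
}
print(match_parent_path(graph, "out", ["MatMul", "Softmax"], [None, 0]))  # ['mm', 'sm']
```

In the real fusion code, index arrays such as `[1, None, 0, 0, 0]` play the role of `input_indices` here: a concrete integer pins the pattern to one input slot, while `None` lets the matcher scan all inputs.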
Reviewed Changes
Copilot reviewed 6 out of 8 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| onnxruntime/test/python/transformers/test_gemma3_vision.py | Adds tests and model definitions for Gemma3 vision attention and layer normalization. |
| onnxruntime/test/python/transformers/test_gelu_fusions.py | Refactors Gelu fusion tests by parameterizing the tests for dynamo export. |
| onnxruntime/python/tools/transformers/fusion_fastgelu.py | Improves robustness in FastGelu fusion by handling cases where the root input comes directly from the graph input. |
| onnxruntime/python/tools/transformers/fusion_attention_clip.py | Generalizes pattern matching for attention fusion to support different index configurations and tensor formats. |
Files not reviewed (2)
- tools/ci_build/github/linux/python/requirements.txt: Language not supported
- tools/ci_build/github/windows/python/requirements.txt: Language not supported
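For background on the fusion_fastgelu.py row above: FastGelu fuses the tanh-based approximation of GELU into a single node. As a reference for the math being fused, here is a plain-Python sketch (mine, not onnxruntime code):

```python
import math

def fast_gelu(x: float) -> float:
    # tanh approximation of GELU:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

print(fast_gelu(1.0))  # ~0.8412, close to exact GELU(1)
```

The fusion pass recognizes this arithmetic as a subgraph of Mul/Add/Tanh/Pow nodes and collapses it into one FastGelu op; the robustness fix in this PR handles the case where the subgraph's root input comes directly from a graph input rather than from another node.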
Comments suppressed due to low confidence (2)
onnxruntime/python/tools/transformers/fusion_attention_clip.py:152
- [nitpick] Consider using a named constant or adding a comment to clarify the purpose of 'None' as a wildcard in the index array for pattern matching.
[1, None, 0, 0, 0],
onnxruntime/python/tools/transformers/fusion_attention_clip.py:232
- [nitpick] Consistently document or use a named constant for wildcard indices (such as 'None') to improve code clarity in pattern matching.
q_nodes = self.model.match_parent_path(
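One way to act on these nitpicks, sketched with illustrative names (`ANY_INPUT` and `describe` are mine, not from the PR):

```python
# Illustrative only: naming the wildcard makes index arrays self-documenting.
ANY_INPUT = None  # wildcard input index for parent-path pattern matching

qk_path_indices = [1, ANY_INPUT, 0, 0, 0]

def describe(indices):
    """Render an index array for logging/debugging pattern matches."""
    return ["any" if i is ANY_INPUT else str(i) for i in indices]

print(describe(qk_path_indices))  # ['1', 'any', '0', '0', '0']
```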
tianleiwu previously approved these changes (Apr 2, 2025)
kunal-vaishnavi previously approved these changes (Apr 2, 2025)
e7d6160
tianleiwu previously approved these changes (Apr 3, 2025)
kunal-vaishnavi previously approved these changes (Apr 3, 2025)
ee96271
snnn approved these changes (Apr 4, 2025)
Contributor
snnn left a comment
The changes under tools/ci_build are fine.
kunal-vaishnavi
approved these changes
Apr 4, 2025
zhaoxul-qti
pushed a commit
to CodeLinaro/onnxruntime
that referenced
this pull request
Apr 17, 2025
### Description
Essentially, the vision model is traced differently (this time without an attention mask), and the input indices of op.Add and op.MatMul can differ. Also, fp16 and fp32 need different tracing patterns (op.Cast).
1. Add another traced pattern to CLIP attention to cover the no-attention_mask case
2. Accept different input indices on op.Add and op.MatMul (be more general)
3. fp16 and fp32 show different patterns (op.Cast after op.Softmax)
4. Refactor test_fastgelu.py to cover torch.onnx.export(..., dynamo=True)
5. Add a Gemma3 vision attention (SigLip) test to cover both fp16 and fp32
### Motivation and Context
To optimize the Gemma3 multi-modal model, these changes are needed. https://huggingface.co/google/gemma-3-4b-it
NOTE: some related follow-ups (upstream optimizations to the onnxscript optimizer):
microsoft/onnxscript#2158
microsoft/onnxscript#2156
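The fp16 difference described here (op.Cast after op.Softmax) amounts to matching two pattern variants. A hedged sketch of the idea, with illustrative names rather than the PR's actual code:

```python
# Sketch only: fp16 traces often insert a Cast right after Softmax, so a
# fusion pass must accept both op sequences walking back from the score MatMul.
FP32_PATH = ["MatMul", "Softmax"]           # ... -> Softmax -> MatMul
FP16_PATH = ["MatMul", "Cast", "Softmax"]   # extra Cast between them

def find_softmax_variant(ops_seen):
    """Return which traced pattern the op sequence matches, if any."""
    for name, path in (("fp32", FP32_PATH), ("fp16", FP16_PATH)):
        if ops_seen[:len(path)] == path:
            return name
    return None

print(find_softmax_variant(["MatMul", "Cast", "Softmax"]))  # fp16
print(find_softmax_variant(["MatMul", "Softmax"]))          # fp32
```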
Description
Essentially, the vision model is traced differently (this time without an attention mask), and the input indices of op.Add and op.MatMul can differ. Also, fp16 and fp32 need different tracing patterns (op.Cast).
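Why can the input indices of op.Add differ between traces? Add is commutative, so different tracings may place the MatMul output on either input slot of the Add node. An illustrative sketch (function and names are mine, not the PR's code):

```python
# Sketch: a matcher that hard-codes input 0 of Add misses traces where the
# exporter put the bias first; scanning both inputs handles both orderings.

def find_matmul_input(add_inputs):
    """add_inputs: list of (producer_op, name) pairs. Return the MatMul side."""
    for idx, (op, name) in enumerate(add_inputs):
        if op == "MatMul":
            return idx, name
    return None

# Same semantics, two possible tracings:
print(find_matmul_input([("MatMul", "mm_out"), ("Initializer", "bias")]))  # (0, 'mm_out')
print(find_matmul_input([("Initializer", "bias"), ("MatMul", "mm_out")]))  # (1, 'mm_out')
```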
Motivation and Context
To optimize the Gemma3 multi-modal model, these changes are needed. https://huggingface.co/google/gemma-3-4b-it
NOTE: some related follow-ups (upstream optimizations to the onnxscript optimizer):
microsoft/onnxscript#2158
microsoft/onnxscript#2156