add fake for MLA RoPE operator #1714
base: main
Conversation
Pull request overview
This PR adds torch.compile support for the fused_qk_rope_cat_and_cache_mla function by introducing a fake tensor function that simulates tensor shapes and dtypes without actual computation. This is required for SGLang's torch.compile integration.
- Adds a fused_qk_rope_cat_and_cache_mla_fake_tensor function to generate fake tensors for torch.compile (see the sketch below)
- Updates the return type to always return 5 tensors (including q_nope_zeros_out) for consistency
- Adds type hints and improves type annotations for better code clarity
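For context, a fake (meta) implementation only mirrors the output shapes and dtypes of the real kernel. The sketch below shows this general pattern using torch.library.register_fake; the op namespace ("aiter::..."), argument names, and output shapes are assumptions made for illustration, not the actual signature from this PR, and the real custom op is assumed to be registered with the dispatcher already.

```python
import torch

# Illustrative sketch only: namespace, arguments, and shapes are assumptions,
# not the signature added in this PR.
@torch.library.register_fake("aiter::fused_qk_rope_cat_and_cache_mla")
def fused_qk_rope_cat_and_cache_mla_fake_tensor(
    q_nope, q_rope, k_nope, k_rope, kv_cache, slot_mapping, cos_sin_cache
):
    num_tokens, num_heads = q_nope.shape[0], q_nope.shape[1]
    head_dim = q_nope.shape[-1] + q_rope.shape[-1]

    # Under FakeTensorMode only shapes and dtypes matter; no RoPE math or cache
    # writes happen here, we only allocate outputs of the expected shape.
    q_out = q_nope.new_empty((num_tokens, num_heads, head_dim))
    k_out = k_nope.new_empty((num_tokens, 1, head_dim))
    q_nope_zeros_out = q_nope.new_empty(q_nope.shape)
    # Placeholders for the remaining two outputs of the 5-tensor return.
    kv_cache_out = kv_cache.new_empty(kv_cache.shape)
    q_rope_out = q_rope.new_empty(q_rope.shape)
    return q_out, k_out, q_nope_zeros_out, kv_cache_out, q_rope_out
```

With such a fake registered, FakeTensorMode (and therefore torch.compile) can trace calls to the op without executing the device kernel.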
k50112113 left a comment
Looks good! Thanks for the addition. I think we are not going to let torch.compile see inside this function in any case, so this is a pretty decent change.
Motivation
For the function fused_qk_rope_cat_and_cache_mla, SGLang needs a fake (meta) implementation of it so the op can pass through torch.compile.
Technical Details
This commit needs a corresponding SGLang commit to be merged simultaneously, because the API has changed.
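To illustrate why a paired SGLang change is needed, a hypothetical call site is sketched below: since the op now always returns 5 tensors (including q_nope_zeros_out), the caller must unpack five values. The variable and argument names are placeholders, not the actual SGLang code.

```python
# Hypothetical SGLang-side call site, shown only to illustrate the unpacking
# change; names are placeholders. A caller written against the old return
# shape would break, hence the simultaneous SGLang commit.
(
    q_out,
    k_out,
    q_nope_zeros_out,
    kv_cache_out,
    q_rope_out,
) = fused_qk_rope_cat_and_cache_mla(
    q_nope, q_rope, k_nope, k_rope, kv_cache, slot_mapping, cos_sin_cache
)
```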
Test Plan
Test Result
Submission Checklist