Skip to content

Conversation

@gramalingam
Copy link
Collaborator

Add initial support for RotaryEmbedding fusion for onnx opset 23

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
@codecov
Copy link

codecov bot commented Jul 14, 2025

❌ 9 Tests Failed:

Tests completed Failed Passed Skipped
16456 9 16447 3852
View the top 3 failed test(s) by shortest run time
::onnxscript.tools.training_helper
Stack Traces | 0s run time
ImportError while importing test module '.../onnxscript/tools/training_helper.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
.../Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript/tools/training_helper.py:6: in <module>
    from torch.onnx import _OrtBackend, _OrtBackendOptions
E   ImportError: cannot import name '_OrtBackend' from 'torch.onnx' (.../onnxscript/onnxscript/.nox.../test_torch_nightly/lib/python3.11.../torch/onnx/__init__.py)
::onnxscript.tools.transformers_models.llama_test
Stack Traces | 0s run time
ImportError while importing test module '.../tools/transformers_models/llama_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
.../Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../tools/transformers_models/llama_test.py:12: in <module>
    import onnxscript.tools.training_helper
onnxscript/tools/training_helper.py:6: in <module>
    from torch.onnx import _OrtBackend, _OrtBackendOptions
E   ImportError: cannot import name '_OrtBackend' from 'torch.onnx' (.../onnxscript/onnxscript/.nox.../test_torch_nightly/lib/python3.11.../torch/onnx/__init__.py)
::onnxscript.tools.transformers_models.mistral_test
Stack Traces | 0s run time
ImportError while importing test module '.../tools/transformers_models/mistral_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
.../Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.../tools/transformers_models/mistral_test.py:14: in <module>
    import onnxscript.tools.training_helper
onnxscript/tools/training_helper.py:6: in <module>
    from torch.onnx import _OrtBackend, _OrtBackendOptions
E   ImportError: cannot import name '_OrtBackend' from 'torch.onnx' (.../onnxscript/onnxscript/.nox.../test_torch_nightly/lib/python3.11.../torch/onnx/__init__.py)

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
@gramalingam gramalingam enabled auto-merge (squash) July 18, 2025 22:42
Comment on lines +13 to +19
# def rotate_half(x):
# """Rotates half the hidden dims of the input."""
# x1 = x[..., : x.shape[-1] // 2]
# x2 = x[..., x.shape[-1] // 2 :]
# return torch.cat((-x2, x1), dim=-1)
# and
# q_embed = (q * cos) + (rotate_half(q) * sin)

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.

Copilot Autofix

AI 6 months ago

The best way to address the issue is to remove the commented-out code and replace it with a concise, well-structured explanation of the referenced logic. The explanation can include a link to the external function's implementation in Hugging Face's repository and a summary of what the function does, ensuring clarity without including raw commented-out code.

Specifically:

  1. Remove the commented-out rotate_half function code (lines 13-20).
  2. Replace it with a concise comment explaining the logic and its relevance to _rotate_half_pattern.
  3. Retain the link to the external repository for further reference.

Suggested changeset 1
onnxscript/rewriter/onnx_fusions/_rotary_embedding.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxscript/rewriter/onnx_fusions/_rotary_embedding.py b/onnxscript/rewriter/onnx_fusions/_rotary_embedding.py
--- a/onnxscript/rewriter/onnx_fusions/_rotary_embedding.py
+++ b/onnxscript/rewriter/onnx_fusions/_rotary_embedding.py
@@ -10,15 +10,11 @@
 
 # Basic pattern: For example, see
 # https://github.com/huggingface/transformers/blob/541bed22d6e4f97946a3a7d74f7e1a353e58643b/src/transformers/models/llama/modeling_llama.py#L104
-#    def rotate_half(x):
-#        """Rotates half the hidden dims of the input."""
-#        x1 = x[..., : x.shape[-1] // 2]
-#        x2 = x[..., x.shape[-1] // 2 :]
-#        return torch.cat((-x2, x1), dim=-1)
-# and
-#        q_embed = (q * cos) + (rotate_half(q) * sin)
+# The Hugging Face implementation includes a function `rotate_half` that splits the input tensor
+# into two halves along the last dimension, rotates one half, and concatenates them back.
+# This logic is used in operations like `q_embed = (q * cos) + (rotate_half(q) * sin)`.
+# The `_rotate_half_pattern` function below implements equivalent functionality using ONNX ops.
 
-
 def _rotate_half_pattern(op, x, start1, end1, start2, end2):
     # Slice(input, starts, ends, axes, steps)
     x1 = op.Slice(x, start1, end1, [3], [1])
EOF
@@ -10,15 +10,11 @@

# Basic pattern: For example, see
# https://github.com/huggingface/transformers/blob/541bed22d6e4f97946a3a7d74f7e1a353e58643b/src/transformers/models/llama/modeling_llama.py#L104
# def rotate_half(x):
# """Rotates half the hidden dims of the input."""
# x1 = x[..., : x.shape[-1] // 2]
# x2 = x[..., x.shape[-1] // 2 :]
# return torch.cat((-x2, x1), dim=-1)
# and
# q_embed = (q * cos) + (rotate_half(q) * sin)
# The Hugging Face implementation includes a function `rotate_half` that splits the input tensor
# into two halves along the last dimension, rotates one half, and concatenates them back.
# This logic is used in operations like `q_embed = (q * cos) + (rotate_half(q) * sin)`.
# The `_rotate_half_pattern` function below implements equivalent functionality using ONNX ops.


def _rotate_half_pattern(op, x, start1, end1, start2, end2):
# Slice(input, starts, ends, axes, steps)
x1 = op.Slice(x, start1, end1, [3], [1])
Copilot is powered by AI and may make mistakes. Always verify output.
@justinchuby
Copy link
Collaborator

Looks like there is merge conflicts

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
@gramalingam
Copy link
Collaborator Author

Looks like there is merge conflicts

Resolved

@gramalingam gramalingam merged commit c33fce2 into main Jul 21, 2025
25 of 32 checks passed
@gramalingam gramalingam deleted the rama/rotary branch July 21, 2025 19:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Development

Successfully merging this pull request may close these issues.

3 participants