Skip to content

Fix IndexError with DeepSpeed ZeRO-3 when kernels rotary is active#45395

Closed
ArthurZucker wants to merge 3 commits intohuggingface:mainfrom
ArthurZucker:worktree-fix-rotary-fn-zero3-45137
Closed

Fix IndexError with DeepSpeed ZeRO-3 when kernels rotary is active#45395
ArthurZucker wants to merge 3 commits intohuggingface:mainfrom
ArthurZucker:worktree-fix-rotary-fn-zero3-45137

Conversation

@ArthurZucker
Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker commented Apr 13, 2026

Summary

Fixes #45137.

Since #41147, attention layers are decorated with @use_kernelized_func(apply_rotary_pos_emb) which attaches a rotary_fn child nn.Module at init when the kernels library is available.

DeepSpeed ZeRO-3's parameter coordinator traces the module graph at init and expects every registered submodule to run during forward. The attention forward still calls the Python apply_rotary_pos_emb, so rotary_fn is never invoked and the parameter-fetch trace desynchronizes, raising:

IndexError: pop from an empty deque
  at deepspeed/runtime/zero/partitioned_param_coordinator.py

on the second forward (reproducible via TRL's RLOO/GRPO trainers under ZeRO-3, see huggingface/trl#4899).

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker requested a review from vasqu April 13, 2026 09:07
When `kernels` is installed, `@use_kernelized_func` attaches a
`rotary_fn` child `nn.Module` to attention layers. DeepSpeed ZeRO-3's
parameter coordinator traces the module graph at init and expects
every registered submodule to be invoked during forward. The model's
forward still calls the plain Python `apply_rotary_pos_emb`, so
`rotary_fn` is never executed and the trace desynchronizes, raising
`IndexError: pop from an empty deque` on the second forward.

Skip attaching the kernelized submodule when ZeRO-3 is enabled; users
running under ZeRO-3 fall back to the Python implementation, which is
what they were getting before huggingface#41147.

Fixes huggingface#45137
@ArthurZucker ArthurZucker force-pushed the worktree-fix-rotary-fn-zero3-45137 branch from f7b48c5 to bc4d35d Compare April 13, 2026 10:05
Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only real question is: Deepspeed will no longer use a kernelized RoPE or does it?

(Also what's up with CI 👀)


-->
*This model was released on 2023-02-06 and added to Hugging Face Transformers on 2026-03-19.*
*This model was released on 2023-02-06 and added to Hugging Face Transformers on 2026-03-21.*
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated? Not a big deal either way

Comment on lines +460 to +463
from .deepspeed import is_deepspeed_zero3_enabled

if is_deepspeed_zero3_enabled():
return
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean the kernels version of RoPE cannot be used with deepspeed this way?

The whole kernelizing processes seems to go through multiple paths and it's a bit hard to keep track of:

  • Kernelized RoPE only happens later (after init?)
  • We register it under the module (as it's a requirement for kernels ig)

It should be fine as quick fix for sure

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think its the way we do functtion kernelize, (today only rope)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but yeah we do at init as module, only triggered lazily I think

@vasqu
Copy link
Copy Markdown
Contributor

vasqu commented Apr 13, 2026

I thought a bit more about this and it might seem that this whole approach is a bit outdated? Imo, it would make more sense to use the lazy load approach as in e.g. Mamba:

global causal_conv1d, causal_conv1d_update, causal_conv1d_fn
causal_conv1d = lazy_load_kernel("causal-conv1d")
causal_conv1d_update = getattr(causal_conv1d, "causal_conv1d_update", None)
causal_conv1d_fn = getattr(causal_conv1d, "causal_conv1d_fn", None)
global mamba_ssm, selective_state_update, selective_scan_fn, mamba_inner_fn
mamba_ssm = lazy_load_kernel("mamba-ssm")
selective_state_update = resolve_internal_import(
mamba_ssm, chained_path="ops.triton.selective_state_update.selective_state_update"
)
selective_scan_fn = getattr(mamba_ssm, "selective_scan_fn", None)
mamba_inner_fn = getattr(mamba_ssm, "mamba_inner_fn", None)

No decorator anymore but keeping it in global variable state. I'm surprised either way that we had to set this as attribute to be honest 😬

@ArthurZucker
Copy link
Copy Markdown
Collaborator Author

ArthurZucker commented Apr 13, 2026

Hhaha yeah its not outdated, that's the issue. It's "new":
This is how it would look like with your comment: #41145
vs
The PR that introduced it #41147! Reason was: #41147 (comment)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

IndexError: pop from an empty deque with DeepSpeed ZeRO3

3 participants