Fix VL model rope_deltas batch size mismatch in online RL training#44873

Merged
Cyrilvallez merged 6 commits into huggingface:main from sergiopaniego:qwen3-5-training-fix
Mar 20, 2026

Conversation

@sergiopaniego
Member

@sergiopaniego sergiopaniego commented Mar 20, 2026

What does this PR do?

Problem

Online RL training (GRPO, RLOO, PPO) with any VL model that uses MRoPE with rope_deltas (Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3.5, GLM4V, PaddleOCR-VL, Ernie4.5-VL-MoE, etc.) crashes with RuntimeError: Sizes of tensors must match in apply_rotary_pos_emb. The crash happens because model.generate(batch_size=N) caches self.model.rope_deltas, and the subsequent training forward pass model(batch_size=M) with M != N reuses those stale deltas, so repeat_interleave(M // N) produces empty tensors.

Reproduction

from trl import GRPOTrainer, GRPOConfig
from datasets import Dataset

trainer = GRPOTrainer(
    model="Qwen/Qwen3.5-2B",
    reward_funcs=lambda completions, **kw: [1.0] * len(completions),
    train_dataset=Dataset.from_dict({"prompt": [[{"role": "user", "content": "Hi"}]] * 8}),
    args=GRPOConfig(num_generations=2, per_device_train_batch_size=1, max_steps=1),
)
trainer.train()  # crashes

Note: the bug only triggers when per_device_train_batch_size < num_generations (generation batch != training batch). With per_device_train_batch_size == num_generations it works because repeat_interleave(N // N) = repeat_interleave(1). Most real training configs use per_device_train_batch_size=1 with num_generations > 1, so this affects the majority of use cases.
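 
As a quick illustration of the failure mode (a standalone sketch, not code taken from the model or TRL), repeating the cached deltas by M // N collapses to zero repeats whenever the training batch M is smaller than the generation batch N:

import torch

cached_deltas = torch.zeros(2, 1)   # rope_deltas cached by generate() with batch_size N = 2
M, N = 1, 2                         # training forward pass with batch_size M = 1
expanded = cached_deltas.repeat_interleave(M // N, dim=0)
print(expanded.shape)               # torch.Size([0, 1]) -> empty tensor reaching apply_rotary_pos_emb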

Fix

Only use pre-calculated rope_deltas during incremental generation (past_key_values_length > 0), not during full forward passes. The existing tests already work around this by manually resetting model.rope_deltas = None between calls (lines 429-430, 447-448 in test_modeling_qwen3_5.py).
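 
For illustration only, the guard boils down to something like the following standalone sketch; the helper name and signature are invented for this example and are not the actual transformers code:

import torch

def should_reuse_rope_deltas(past_key_values_length, cached_deltas):
    # Reuse cached deltas only while decoding incrementally; a fresh full forward
    # (e.g. the training pass after generate) must recompute them for the current batch.
    return past_key_values_length > 0 and cached_deltas is not None

assert not should_reuse_rope_deltas(0, torch.zeros(2, 1))  # training forward: recompute
assert should_reuse_rope_deltas(5, torch.zeros(2, 1))      # incremental decoding: reuse cache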

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @Cyrilvallez

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sergiopaniego sergiopaniego changed the title Fix Qwen3.5 rope_deltas batch size mismatch in training forward pass Fix Qwen VL rope_deltas batch size mismatch in online RL training Mar 20, 2026
@sergiopaniego sergiopaniego changed the title Fix Qwen VL rope_deltas batch size mismatch in online RL training Fix VL model rope_deltas batch size mismatch in online RL training Mar 20, 2026
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: ernie4_5_vl_moe, glm46v, glm4v, glm4v_moe, glm_image, glm_ocr, paddleocr_vl, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_vl, qwen3_vl_moe

Member

@Cyrilvallez Cyrilvallez left a comment

Hey @sergiopaniego! LGTM, thanks a lot! @zucchini-nlp can you quickly confirm that we won't have any other issues from this before we merge? But it looks very logical

Member

@zucchini-nlp zucchini-nlp left a comment

Yep, looks reasonable to me.

And it's actually a good time to deprecate rope_deltas as an instance attribute and pass it as an argument instead. I remember we deprecated the arg with Joao because generation had no proper support for position ids at the time, which isn't the case anymore
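 
As a toy illustration of that direction (purely hypothetical, not the transformers API), passing the deltas per call makes stale cross-call state impossible:

import torch

def forward_with_explicit_deltas(hidden_states, rope_deltas):
    # Hypothetical stand-in: deltas arrive as an argument sized to the current batch,
    # so a previous generate() call cannot leave a mismatched cache behind.
    assert rope_deltas.shape[0] == hidden_states.shape[0]
    return hidden_states + rope_deltas

out = forward_with_explicit_deltas(torch.zeros(4, 8), torch.zeros(4, 1))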

@Cyrilvallez Cyrilvallez added this pull request to the merge queue Mar 20, 2026
Merged via the queue into huggingface:main with commit b7164ec Mar 20, 2026
22 checks passed
@sergiopaniego sergiopaniego deleted the qwen3-5-training-fix branch March 20, 2026 13:51