Fix VL model rope_deltas batch size mismatch in online RL training #44873
Cyrilvallez merged 6 commits into huggingface:main from …to qwen3-5-training-fix

Conversation

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
[For maintainers] Suggested jobs to run (before merge): run-slow: ernie4_5_vl_moe, glm46v, glm4v, glm4v_moe, glm_image, glm_ocr, paddleocr_vl, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_vl, qwen3_vl_moe
Cyrilvallez
left a comment
Hey @sergiopaniego! LGTM, thanks a lot! @zucchini-nlp can you quickly confirm that we won't have any other issues from this before we merge? But it looks very logical.
zucchini-nlp
left a comment
Yep, looks reasonable to me.
And it's actually a good time to deprecate `rope_deltas` as an instance attribute and pass it as an argument instead. I remember we deprecated the argument with Joao because generation had no proper support for position ids, which is no longer the case.
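For context, the suggested deprecation would move from caching the value on the module to threading it through the call. A minimal sketch of that pattern, with hypothetical names that are not the actual transformers API:

```python
import warnings


class VLModelSketch:
    """Toy model illustrating the deprecation pattern discussed above:
    `rope_deltas` stored as an instance attribute (stale across calls)
    vs. passed explicitly as a forward argument. All names are illustrative.
    """

    def __init__(self):
        self.rope_deltas = None  # legacy cached state set by generate()

    def forward(self, batch, rope_deltas=None):
        if rope_deltas is None and self.rope_deltas is not None:
            warnings.warn(
                "Reading rope_deltas from the instance is deprecated; "
                "pass it explicitly to forward().",
                FutureWarning,
            )
            rope_deltas = self.rope_deltas
        # ... compute position ids from rope_deltas ...
        return rope_deltas


model = VLModelSketch()
# Explicit argument: no warning, no reliance on hidden state.
assert model.forward(batch=[1, 2], rope_deltas=[0, 0]) == [0, 0]
```

The explicit argument makes the dependency visible at every call site, so a training forward pass can never silently pick up deltas cached by a previous `generate()` call.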
What does this PR do?
Problem
Online RL training (GRPO, RLOO, PPO) with all VL models using MRoPE with `rope_deltas` (Qwen2-VL, Qwen2.5-VL, Qwen3-VL, Qwen3.5, GLM4V, PaddleOCR-VL, Ernie4.5-VL-MoE, etc.) crashes with `RuntimeError: Sizes of tensors must match` in `apply_rotary_pos_emb`. This happens because `model.generate(batch_size=N)` sets `self.model.rope_deltas`, and the subsequent training forward pass `model(batch_size=M)` with M != N uses the stale deltas, causing `repeat_interleave(M // N)` to produce empty tensors.

Reproduction
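A minimal sketch of the failure mode described above, using plain Python lists and assumed batch sizes rather than the real model and torch tensors:

```python
# During generate() the model caches one rope delta per sequence in the
# generation batch (N sequences). The training forward pass then runs with
# a smaller batch of M sequences and tries to expand the cached deltas
# with repeat_interleave(M // N).

def repeat_interleave(values, repeats):
    """Pure-Python stand-in for torch.Tensor.repeat_interleave(repeats)."""
    out = []
    for v in values:
        out.extend([v] * repeats)
    return out


N = 8  # generation batch size (e.g. driven by num_generations)
M = 2  # training batch size (per_device_train_batch_size)
cached_rope_deltas = [3, 3, 5, 5, 7, 7, 9, 9]  # one delta per generated sequence

# Integer division truncates: M // N == 0, so every element is repeated
# zero times and the result is empty, which later fails the size check
# in apply_rotary_pos_emb.
expanded = repeat_interleave(cached_rope_deltas, M // N)
print(len(expanded))  # 0
```

With `M == N` the repeat factor is 1 and the stale deltas happen to have the right shape, which is why the crash only surfaces when the two batch sizes differ.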
Note: the bug only triggers when `per_device_train_batch_size < num_generations` (generation batch != training batch). With `per_device_train_batch_size == num_generations` it works because `repeat_interleave(N // N) = repeat_interleave(1)`. Most real training configs use `per_device_train_batch_size=1` with `num_generations > 1`, so this affects the majority of use cases.

Fix
Only use the pre-calculated `rope_deltas` during incremental generation (`past_key_values` length > 0), not during full forward passes. The existing tests already work around this by manually resetting `model.rope_deltas = None` between calls (lines 429-430, 447-448 in `test_modeling_qwen3_5.py`).

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@ArthurZucker @Cyrilvallez