Prepare and keep track of position ids in generate #43734

Merged

zucchini-nlp merged 26 commits into huggingface:main from zucchini-nlp:position-ids-precompute-once-per-generate
Feb 12, 2026

Conversation

@zucchini-nlp
Member

zucchini-nlp commented Feb 4, 2026

What does this PR do?

As per title. Lays the groundwork for unifying 3D position ids in qwen-style VLMs.

The PR adds a single entrypoint for preparing position ids in GenerationMixin, which models can override if needed (qwen-vl, for example). This allows users to prepare their own position ids and pass them to generate(). During the decoding stages, the position ids are simply incremented by one to build the next positions.
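
For illustration, a rough sketch of the flow described above (standalone toy code with hypothetical helper names, not the actual GenerationMixin API):

```python
import torch

def prepare_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    # Prefill: derive positions once from the attention mask (left padding assumed).
    position_ids = attention_mask.long().cumsum(-1) - 1
    return position_ids.masked_fill(attention_mask == 0, 0)

def next_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # Decoding: reuse the previous positions and just increment the last one by one.
    return position_ids[..., -1:] + 1

mask = torch.tensor([[0, 0, 1, 1]])       # one left-padded sequence
positions = prepare_position_ids(mask)    # tensor([[0, 0, 0, 1]])
positions = next_position_ids(positions)  # tensor([[2]]) for the first generated token
```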

Along with it, the PR starts a light unification of 3D positions by splitting them into their own utility fn. Now only two or three models have their own compute_3d_positions, and all other models copy from there. In the next PR, I will split get_rope_index into smaller components, allowing us to copy similarities easily. I am working on it locally, but it's blocked by the current branch.

Review starting from transformers/generation and the models from which we copy (qwen2-vl and ernie4_5_vl).

Fixes #29149

@github-actions
Contributor

github-actions Bot commented Feb 4, 2026

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43734&sha=929e2c

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp
Member Author

run-slow: colqwen2, ernie4_5_vl_moe, gemma3, glm46v, glm4v, glm4v_moe, glm_image, glm_ocr, gpt_neo, paddleocr_vl, qwen2_5_vl, qwen2_vl, qwen3_vl, qwen3_vl_moe, reformer, video_llama_3

@github-actions
Contributor

github-actions Bot commented Feb 4, 2026

This comment contains run-slow, running the specified jobs:

models: ["models/colqwen2", "models/ernie4_5_vl_moe", "models/gemma3", "models/glm46v", "models/glm4v", "models/glm4v_moe", "models/glm_image", "models/glm_ocr", "models/gpt_neo", "models/paddleocr_vl", "models/qwen2_5_vl", "models/qwen2_vl", "models/qwen3_vl", "models/qwen3_vl_moe", "models/reformer", "models/video_llama_3"]
quantizations: []

@github-actions
Contributor

github-actions Bot commented Feb 4, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context  Commit    Description
RUN      692a6f4e  merge commit
PR       819cecb1  branch commit
main     452c179e  base commit

✅ No failing test specific to this PR 🎉 👏 !

@zucchini-nlp
Member Author

run-slow: colqwen2, ernie4_5_vl_moe, gemma3, glm46v, glm4v, glm4v_moe, glm_image, glm_ocr, gpt_neo, paddleocr_vl, qwen2_5_vl, qwen2_vl, qwen3_vl, qwen3_vl_moe, reformer, video_llama_3

@github-actions
Contributor

github-actions Bot commented Feb 6, 2026

This comment contains run-slow, running the specified jobs:

models: ["models/colqwen2", "models/ernie4_5_vl_moe", "models/gemma3", "models/glm46v", "models/glm4v", "models/glm4v_moe", "models/glm_image", "models/glm_ocr", "models/gpt_neo", "models/paddleocr_vl", "models/qwen2_5_vl", "models/qwen2_vl", "models/qwen3_vl", "models/qwen3_vl_moe", "models/reformer", "models/video_llama_3"]
quantizations: []

Comment on lines -656 to +657
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids.masked_fill_(attention_mask == 0, 0)
Member Author

It doesn't make a difference which value we use, because the token is masked anyway. Using 0 makes more sense: with 1, a sequence with only one unmasked token ends up with position ids whose max value is 1 instead of 0.
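
A quick standalone illustration of the point above (toy values, not repo code):

```python
import torch

# A sequence with a single unmasked token (left padding).
attention_mask = torch.tensor([[0, 0, 0, 1]])
base = attention_mask.long().cumsum(-1) - 1  # tensor([[-1, -1, -1, 0]])

old = base.masked_fill(attention_mask == 0, 1)  # tensor([[1, 1, 1, 0]]) -> max position is 1
new = base.masked_fill(attention_mask == 0, 0)  # tensor([[0, 0, 0, 0]]) -> max position is 0
```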

Comment on lines -669 to 670
model_input = model_input[:, -current_input_length:]
model_input = model_input[..., -current_input_length:]
model_input = model_input.clone(memory_format=torch.contiguous_format)
Member Author

3D positions support
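
For context, a small shape check (illustrative only) on why the `...` slicing is needed once position ids can be 3D, i.e. shaped (3, batch, seq_len) for mrope:

```python
import torch

pos_2d = torch.arange(6).view(1, 6)                      # (batch, seq_len)
pos_3d = torch.arange(6).view(1, 1, 6).expand(3, 1, 6)   # (3, batch, seq_len)

n = 1  # keep only the latest position while decoding
a = pos_2d[:, -n:]    # shape (1, 1): slices the sequence dim, as intended
b = pos_3d[:, -n:]    # shape (3, 1, 6): slices the batch dim instead of the sequence dim
c = pos_3d[..., -n:]  # shape (3, 1, 1): slices the sequence dim for both 2D and 3D inputs
```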

Comment on lines -163 to -169
position_ids, rope_deltas = self.vlm.model.get_rope_index(
input_ids=input_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=None,
attention_mask=attention_mask,
)

Member Author

why we called it here, no idea. Better to let self.vlm handle everything

Comment on lines -1248 to -1267
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)

return position_ids, mrope_position_deltas
Member Author

same thing as it was for all get_rope_index, just deleted this part

@github-actions
Contributor

github-actions Bot commented Feb 6, 2026

CI Results

Workflow Run ⚙️

Commit Info

Context  Commit    Description
RUN      20268214  merge commit
PR       b3a9cb61  branch commit
main     49dd2979  base commit

✅ No failing test specific to this PR 🎉 👏 !

Contributor

@vasqu vasqu left a comment

Just left a few questions / nits, I have a feeling we can use modular a tad more re compute_3d_position_ids?

Comment thread src/transformers/generation/candidate_generator.py Outdated
Comment thread src/transformers/generation/candidate_generator.py Outdated
Comment thread src/transformers/generation/utils.py
Comment thread src/transformers/generation/utils.py
Comment thread src/transformers/generation/utils.py
Comment thread src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py
image_outputs.pooler_output = image_embeds
return image_outputs

def compute_3d_position_ids(
Contributor

Could we not inherit from Qwen2_5_VLModel? Or is there something specific, let's avoid rewriting where possible

Member Author

Ernie uses mm_token_type_ids but Qwen2-5VL has second_grid_ts. We can do it if we hide extra kwargs as **kwargs, which basically will look like the above comment

Contributor

Ah, yea that's a good point. On another note, do we want to change the other VLMs to use mm token type ids here? Iirc, it's much faster(?)

Member Author

we do! That is my next PR after this one is merged. I have some stuff locally :)

Comment thread src/transformers/models/gpt_neo/modeling_gpt_neo.py
Comment thread tests/generation/test_utils.py
Comment thread tests/generation/test_utils.py
@vasqu
Contributor

vasqu commented Feb 9, 2026

run-slow might have been broken, so better to rerun after merging with main 👀

Contributor

@vasqu vasqu left a comment

Forgot to approve, can you also check if qwen3_5(_moe) have inherited as expected? They are essentially coming from qwen3 vl

@zucchini-nlp
Member Author

Oke, will run a few slow tests and merge

@zucchini-nlp
Member Author

run-slow: colqwen2, ernie4_5_vl_moe, gemma3, glm46v, glm4v, glm4v_moe, glm_image, glm_ocr, gpt_neo, idefics, paddleocr_vl, qwen2_5_vl, qwen2_vl, qwen3_vl, qwen3_vl_moe, reformer

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/colqwen2", "models/ernie4_5_vl_moe", "models/gemma3", "models/glm46v", "models/glm4v", "models/glm4v_moe", "models/glm_image", "models/glm_ocr", "models/gpt_neo", "models/idefics", "models/paddleocr_vl", "models/qwen2_5_vl", "models/qwen2_vl", "models/qwen3_vl", "models/qwen3_vl_moe", "models/reformer"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context  Commit    Description
RUN      6fcbb100  merge commit
PR       3771d007  branch commit
main     44f92b63  base commit

Model CI Report

7 new failed tests from this PR 😭

  • colqwen2:
    tests/models/colqwen2/test_modeling_colqwen2.py::ColQwen2ForRetrievalModelTest::test_torch_export

  • ernie4_5_vl_moe:
    tests/models/ernie4_5_vl_moe/test_modeling_ernie4_5_vl_moe.py::Ernie4_5_VL_MoeSmallIntegrationTest::test_small_model_integration_test
    tests/models/ernie4_5_vl_moe/test_modeling_ernie4_5_vl_moe.py::Ernie4_5_VL_MoeSmallIntegrationTest::test_small_model_integration_test_batch
    tests/models/ernie4_5_vl_moe/test_modeling_ernie4_5_vl_moe.py::Ernie4_5_VL_MoeSmallIntegrationTest::test_small_model_integration_test_expand
    tests/models/ernie4_5_vl_moe/test_modeling_ernie4_5_vl_moe.py::Ernie4_5_VL_MoeSmallIntegrationTest::test_small_model_integration_test_with_video

  • qwen2_5_vl:
    tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py::Qwen2_5_VLIntegrationTest::test_small_model_integration_test_batch_different_resolutions
    tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py::Qwen2_5_VLIntegrationTest::test_small_model_integration_test_batch_wo_image

@Cyrilvallez
Member

Hey! Sorry I'm a bit late to the party! Instead of having a dedicated _prepare_position_ids_for_generation method everywhere for a lot of models, wdyt about simply giving the responsibility of creating position_ids to the model during the first forward? Similar to what we do with caches. That way, we don't have to maintain such additional methods, and generate reuses the prepared position_ids after the first forward
@zucchini-nlp @vasqu

@zucchini-nlp
Member Author

If we make each model compute its position ids in forward (which already happens now, in a not-so-correct way), we can't just build upon them by incrementing to the next position. Position ids aren't returned from the model like the cache is, so we would have to start returning them from forward to be able to re-use them. Otherwise we just have to let each model re-compute positions from scratch every time, which is basically what happens now

Actually, I thought at first to give each BaseModel its own compute_position_ids method and have generation simply call base_model.compute_position_ids. Yet all models compute it the same way except for qwen-vl and paligemma, so why not just keep it in the generation mixin and override it in special cases

@zucchini-nlp
Member Author

7 new failed tests from this PR 😭

Ah, that was the issue with padding side; it's supposed to be "left". We don't recompute positions every time, thus it doesn't work well with right padding
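
A toy example of why reusing and incrementing cached positions assumes left padding (illustrative only, using the same masking as above):

```python
import torch

def prefill_positions(mask: torch.Tensor) -> torch.Tensor:
    pos = mask.long().cumsum(-1) - 1
    return pos.masked_fill(mask == 0, 0)

left = torch.tensor([[0, 0, 1, 1]])   # left padding
right = torch.tensor([[1, 1, 0, 0]])  # right padding

pos_left = prefill_positions(left)    # tensor([[0, 0, 0, 1]])
pos_right = prefill_positions(right)  # tensor([[0, 1, 0, 0]])

# Decoding just increments the last column:
next_left = pos_left[..., -1:] + 1    # tensor([[2]]): correct next position
next_right = pos_right[..., -1:] + 1  # tensor([[1]]): wrong, the real next position would be 2
```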

@vasqu
Contributor

vasqu commented Feb 11, 2026

Updated on the hub! So ernie should behave somewhat normally now, no idea why the default padding changed tbh

@zucchini-nlp
Member Author

run-slow: colqwen2, ernie4_5_vl_moe, qwen2_5_vl

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/colqwen2", "models/ernie4_5_vl_moe", "models/qwen2_5_vl"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context  Commit    Description
RUN      287c9e43  merge commit
PR       423e4b86  branch commit
main     b52b6631  base commit

Model CI Report

5 new failed tests from this PR 😭

  • ernie4_5_vl_moe:
    tests/models/ernie4_5_vl_moe/test_modeling_ernie4_5_vl_moe.py::Ernie4_5_VL_MoeSmallIntegrationTest::test_small_model_integration_test
    tests/models/ernie4_5_vl_moe/test_modeling_ernie4_5_vl_moe.py::Ernie4_5_VL_MoeSmallIntegrationTest::test_small_model_integration_test_batch
    tests/models/ernie4_5_vl_moe/test_modeling_ernie4_5_vl_moe.py::Ernie4_5_VL_MoeSmallIntegrationTest::test_small_model_integration_test_expand
    tests/models/ernie4_5_vl_moe/test_modeling_ernie4_5_vl_moe.py::Ernie4_5_VL_MoeSmallIntegrationTest::test_small_model_integration_test_with_video

  • qwen2_5_vl:
    tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py::Qwen2_5_VLIntegrationTest::test_small_model_integration_test_batch_wo_image

@zucchini-nlp
Member Author

run-slow: ernie4_5_vl_moe, qwen2_5_vl

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/ernie4_5_vl_moe", "models/qwen2_5_vl"]
quantizations: []

@Cyrilvallez
Member

Position ids aren't returned from model like cache

Arhhh, you're right 🥲 Nevermind then!

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context  Commit    Description
RUN      fb65d6f2  merge commit
PR       093c2329  branch commit
main     ac6cba66  base commit

✅ No failing test specific to this PR 🎉 👏 !

@zucchini-nlp zucchini-nlp enabled auto-merge (squash) February 12, 2026 08:20
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: colqwen2, ernie4_5_vl_moe, gemma3, glm46v, glm4v, glm4v_moe, glm_image, glm_ocr, gpt_neo, paddleocr_vl, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_vl, qwen3_vl_moe


Development

Successfully merging this pull request may close these issues.

Generate: support passing position_ids

4 participants