
Extract dynamic vision/audio tensors into standalone pure functions #45396

Open
IlyasMoutawwakil wants to merge 63 commits into main from hf-vision-audio-utils

Conversation

@IlyasMoutawwakil (Member) commented Apr 13, 2026

Needed both Claude and Copilot's help on this one 😅 The idea is to make the VLMs and their visual/audio encoders compilable / exportable. Here's a demo of the full model forward being compilable with these precomputed tensors.

"""Demo: precomputed vision tensors enable torch.compile.

Without precomputation the vision encoder uses loops / .tolist() / repeat_interleave
that break torch.compile fullgraph mode. By computing ``image_cu_seqlens`` and
``image_position_ids`` via ``transformers.vision_utils`` outside the traced region
and passing them into the model, the vision path becomes compile-friendly.

These precomputed tensors are a power-feature: they are not in any public ``forward()``
signature. They flow through ``**kwargs`` from ``model(...)`` down to the vision
encoder, where ``@handle_extra_kwargs(modality="image")`` on ``get_image_features``
strips the ``image_`` prefix so the encoder sees them as ``cu_seqlens``/``position_ids``
in its kwargs. The encoder then calls ``get_vision_cu_seqlens(grid_thw, kwargs=kwargs)``
(and the analogous ``get_vision_position_ids(...)``); each util pops its key from the
caller's kwargs and returns the precomputed tensor, falling back to the (uncompilable)
compute path only if absent.
"""

import torch

from transformers import AutoModelForImageTextToText, AutoProcessor, set_seed
from transformers.vision_utils import get_vision_cu_seqlens, get_vision_position_ids


set_seed(42, deterministic=True)

model_id = "Qwen/Qwen2-VL-2B-Instruct"

print("Loading model and processor...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForImageTextToText.from_pretrained(model_id, dtype=torch.float32, device_map=device).eval()
processor = AutoProcessor.from_pretrained(model_id)

# --- 1. Prepare a multimodal input ---
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

# --- 2. Processor output (baseline) ---
inputs = processor(text=text, images=messages[0]["content"][0]["url"], return_tensors="pt").to(model.device)
print(f"Processor keys: {sorted(inputs.keys())}")

# --- 3. Precompute vision tensors from grid_thw using vision_utils ---
# These replace the untraceable loops inside the vision encoder. The kwargs are
# `image_*`-prefixed at the model boundary; `@handle_extra_kwargs(modality="image")`
# on `get_image_features` strips the prefix before forwarding to the encoder.
inputs_extra = {**inputs}
spatial_merge_size = model.config.vision_config.spatial_merge_size
inputs_extra["image_cu_seqlens"] = get_vision_cu_seqlens(inputs["image_grid_thw"])
inputs_extra["image_position_ids"] = get_vision_position_ids(inputs["image_grid_thw"], spatial_merge_size)
print(f"image_cu_seqlens shape: {inputs_extra['image_cu_seqlens'].shape}")
print(f"image_position_ids shape: {inputs_extra['image_position_ids'].shape}")

# Precompute 3D text position_ids (M-RoPE) so the model skips get_rope_index at runtime
inputs_extra["position_ids"], _ = model.model.get_rope_index(
    inputs_extra["input_ids"],
    mm_token_type_ids=inputs_extra["mm_token_type_ids"],
    image_grid_thw=inputs_extra["image_grid_thw"],
    attention_mask=inputs_extra["attention_mask"],
)
print(f"position_ids shape: {inputs_extra['position_ids'].shape}")

# --- 4. Eager forward (reference, no precomputed tensors) ---
print("\n=== Eager forward ===")
with torch.no_grad():
    out_eager = model(**inputs)
print(f"Logits shape: {out_eager.logits.shape}")

# --- 5. Compile the full model with fullgraph=True ---
print("\n=== torch.compile(model, fullgraph=True) ===")
model = torch.compile(model, fullgraph=True)

# --- 6. Without precomputed tensors: full model forward fails in vision ---
print("\nWithout precomputed tensors:")
try:
    with torch.no_grad():
        out = model(**inputs)
    print("  Unexpectedly succeeded!")
except Exception as e:
    print(f"  FAILED as expected: {type(e).__name__}")

# --- 7. With precomputed tensors: full model forward succeeds ---
print("\nWith precomputed tensors:")
with torch.no_grad():
    out_compiled = model(**inputs_extra)
print(f"  SUCCESS! Logits shape: {out_compiled.logits.shape}")

# --- 8. Verify compiled output matches eager ---
print("\n=== Verification: compiled vs eager ===")
diff = (out_eager.logits - out_compiled.logits).abs().max().item()
print(f"Max abs diff: {diff:.2e} {'OK' if diff < 1e-2 else 'MISMATCH'}")
  • Created top-level modeling_vision_utils.py with shared pure functions: get_vision_cu_seqlens, get_rotary_pos_ids, get_rotary_pos_ids_interleaved, get_window_index, get_pos_embed_indices
  • Moved audio precompute functions (chunk_and_pad_features, get_audio_cu_seqlens, get_valid_indices, get_pool_indices) into modular files directly
  • Simplified VisionRotaryEmbedding.forward to accept a pos_ids tensor directly via broadcast multiply, eliminating data-dependent table creation (a minimal sketch of this idea follows after the model list below)
  • Made vision/audio encoder forwards accept optional precomputed tensors (cu_seqlens, rotary_pos_ids, window_index, embed_indices, etc.)
  • Used explicit naming: get_vision_cu_seqlens / get_audio_cu_seqlens

Models: qwen2_vl, qwen2_5_vl, qwen3_vl, qwen3_5, qwen3_vl_moe, qwen3_5_moe, qwen2_5_omni, qwen3_omni_moe, glm4v, glm4v_moe, glm_image, glm_ocr, ernie4_5_vl_moe, video_llama_3, mlcd, paddleocr_vl
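
To make the VisionRotaryEmbedding bullet above concrete, here is a minimal sketch of the broadcast-multiply idea. It is illustrative only (the class name is suffixed with `Sketch` on purpose); the actual module in the PR may differ in shapes and details.

import torch
from torch import nn


class VisionRotaryEmbeddingSketch(nn.Module):
    """Instead of building a frequency table sized by a data-dependent max position
    and indexing into it, take the (possibly precomputed) pos_ids and broadcast-
    multiply them against inv_freq."""

    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, pos_ids: torch.Tensor) -> torch.Tensor:
        # pos_ids: integer positions, e.g. shape (seq_len,) or (seq_len, 2);
        # output: rotary frequencies of shape (*pos_ids.shape, dim // 2).
        return pos_ids.unsqueeze(-1).to(self.inv_freq.dtype) * self.inv_freq

In this scheme, pos_ids would come either from precomputed tensors passed into the encoder or from the fallback compute path inside it.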

What does this PR do?

Fixes # (issue)

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

IlyasMoutawwakil and others added 5 commits April 13, 2026 10:44
- Create top-level `modeling_vision_utils.py` with shared pure functions:
  `get_vision_cu_seqlens`, `get_rotary_pos_ids`, `get_rotary_pos_ids_interleaved`,
  `get_window_index`, `get_pos_embed_indices`
- Move audio precompute functions (`chunk_and_pad_features`, `get_audio_cu_seqlens`,
  `get_valid_indices`, `get_pool_indices`) into modular files directly
- Simplify `VisionRotaryEmbedding.forward` to accept `pos_ids` tensor directly
  via broadcast multiply, eliminating data-dependent table creation
- Make vision/audio encoder forwards accept optional precomputed tensors
  (`cu_seqlens`, `rotary_pos_ids`, `window_index`, `embed_indices`, etc.)
- Use explicit naming: `get_vision_cu_seqlens` / `get_audio_cu_seqlens`

Models: qwen2_vl, qwen2_5_vl, qwen3_vl, qwen3_5, qwen3_vl_moe, qwen3_5_moe,
qwen2_5_omni, qwen3_omni_moe, glm4v, glm4v_moe, glm_image, glm_ocr,
ernie4_5_vl_moe, video_llama_3, mlcd, paddleocr_vl

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copilot AI (Contributor) left a comment

Pull request overview

This PR refactors multimodal (vision/audio) models to share “pure” tensor-building utilities and to optionally accept precomputed tensors (e.g., cu_seqlens / rotary pos ids), reducing duplicated logic across many model implementations and processors.

Changes:

  • Added src/transformers/modeling_vision_utils.py with standalone helpers (e.g., get_vision_cu_seqlens, get_rotary_pos_ids, get_window_index, get_pos_embed_indices) and updated multiple models/processors to use them.
  • Updated multiple vision encoders to accept optional precomputed tensors (cu_seqlens, rotary_pos_ids, window_index, embed_indices, etc.) and simplified rotary embedding computation to take pos_ids directly.
  • Refactored audio precompute logic into modular model files and added processor support for returning extra precomputed tensors via return_extra_tensors.
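
If the processor-side flag works the way this summary describes, usage might look roughly like the following. This is hypothetical: the `return_extra_tensors` flag name is taken from this summary, and the exact returned key names are an assumption.

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
messages = [
    {
        "role": "user",
        "content": [{"type": "image", "url": url}, {"type": "text", "text": "Describe this image."}],
    }
]
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

# Ask the processor to also return the precomputed vision tensors so they can be
# forwarded to the model together with the usual inputs.
inputs = processor(text=text, images=url, return_extra_tensors=True, return_tensors="pt")
print(sorted(inputs.keys()))  # expected to additionally include keys such as image_cu_seqlens / rotary pos ids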

Reviewed changes

Copilot reviewed 37 out of 37 changed files in this pull request and generated 5 comments.

File Description
src/transformers/utils/auto_docstring.py Adds new documented processor/model kwargs for precomputed vision tensors.
src/transformers/models/video_llama_3/processing_video_llama_3.py Allows optionally returning precomputed vision tensors from the processor.
src/transformers/models/video_llama_3/modular_video_llama_3.py Switches vision rotary/cu_seqlens generation to shared helpers and adds optional precomputed inputs.
src/transformers/models/video_llama_3/modeling_video_llama_3.py Same as modular: uses shared helpers and updates rotary embedding forward API.
src/transformers/models/qwen3_vl/processing_qwen3_vl.py Adds optional return of precomputed cu_seqlens/rotary pos ids (interleaved variant).
src/transformers/models/qwen3_vl/modular_qwen3_vl.py Moves pos-embed/rotary/cu_seqlens computations to shared helpers; adds optional precomputed inputs.
src/transformers/models/qwen3_vl/modeling_qwen3_vl.py Same refactor as modular file (generated modeling).
src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py Same vision refactor for MoE variant.
src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py Moves audio chunking/cu_seqlens/valid index logic into pure helpers + forward accepts optional precomputes.
src/transformers/models/qwen3_5/modular_qwen3_5.py Refactors vision pos/rotary/cu_seqlens computations; adds optional precomputed inputs.
src/transformers/models/qwen3_5/modeling_qwen3_5.py Same vision refactor for generated modeling file.
src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py Same vision refactor for MoE variant.
src/transformers/models/qwen2_vl/processing_qwen2_vl.py Adds optional return of precomputed cu_seqlens/rotary pos ids from processor.
src/transformers/models/qwen2_vl/modeling_qwen2_vl.py Uses shared get_rotary_pos_ids / get_vision_cu_seqlens and accepts optional precomputed tensors.
src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py Adds optional return of precomputed cu_seqlens/rotary pos ids from processor.
src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py Refactors rotary/cu_seqlens/window indexing via shared helpers; adds optional precomputed inputs.
src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py Same refactor for generated modeling file.
src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py Moves audio chunking/indices/cu_seqlens/pooling computations into pure helper functions and accepts optional precomputes.
src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py Updates PaddleOCR vision path to use shared rotary/cu_seqlens helpers and renames args (grid_thw).
src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py Same as modular: shared helper usage and argument renames.
src/transformers/models/glm4v/processing_glm4v.py Adds optional return of precomputed cu_seqlens/rotary pos ids from processor.
src/transformers/models/glm4v/modular_glm4v.py Refactors vision rotary/cu_seqlens computations and video grid flattening logic.
src/transformers/models/glm4v/modeling_glm4v.py Same as modular file, plus updates to rotary embedding forward API.
src/transformers/models/glm4v_moe/modeling_glm4v_moe.py Same refactor for MoE variant.
src/transformers/models/glm46v/processing_glm46v.py Adds optional return of precomputed cu_seqlens/rotary pos ids from processor.
src/transformers/models/glm46v/modeling_glm46v.py Passes optional precomputed vision tensors through get_*_features and vision tower.
src/transformers/models/glm_ocr/modular_glm_ocr.py Refactors vision rotary/cu_seqlens computation to shared helpers.
src/transformers/models/glm_ocr/modeling_glm_ocr.py Same vision refactor for generated modeling file.
src/transformers/models/glm_image/modular_glm_image.py Refactors rotary pos ids and cu_seqlens to shared helpers; adds optional precomputed inputs.
src/transformers/models/glm_image/modeling_glm_image.py Same as modular file, plus updates to rotary embedding forward API.
src/transformers/models/esm/configuration_esm.py Moves rope_theta doc section to align with parameter ordering/docs.
src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py Refactors vision rotary/cu_seqlens to shared helpers and accepts optional precomputes.
src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py Same refactor for generated modeling file.
src/transformers/modeling_vision_utils.py New shared pure functions for vision tensor precomputation.
docs/source/en/model_doc/nomic_bert.md Updates NomicBERT paper link.

Comment thread src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py Outdated
Comment thread src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py Outdated
Comment thread src/transformers/vision_utils.py Outdated
Comment thread src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py Outdated
Comment thread src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py Outdated
Copilot AI (Contributor) left a comment

Pull request overview

This PR refactors multimodal (vision/audio) models to allow passing precomputed, data-dependent tensors (e.g. cu_seqlens, rotary position IDs, window indices, position-embedding interpolation indices) and centralizes shared vision tensor construction into a new src/transformers/modeling_vision_utils.py.

Changes:

  • Add src/transformers/modeling_vision_utils.py with shared pure helper functions for vision precomputations (get_vision_cu_seqlens, rotary pos IDs, window indices, pos-embed interpolation indices).
  • Update many vision model/processor implementations to accept optional precomputed tensors and avoid rebuilding them inside forward.
  • Move/inline audio precompute helpers into relevant modular/model files and update docstrings/autodoc arg definitions accordingly.

Reviewed changes

Copilot reviewed 37 out of 37 changed files in this pull request and generated 8 comments.

File Description
src/transformers/utils/auto_docstring.py Adds autodoc entries for new optional precomputed image/video tensors.
src/transformers/modeling_vision_utils.py New shared pure functions for vision tensor precomputation (cu_seqlens, rotary pos IDs, window indices, pos-embed indices/weights).
src/transformers/models/video_llama_3/processing_video_llama_3.py Processor can optionally return extra precomputed vision tensors.
src/transformers/models/video_llama_3/modular_video_llama_3.py Vision forward accepts optional precomputed tensors; uses shared vision utils.
src/transformers/models/video_llama_3/modeling_video_llama_3.py Generated modeling updated to accept/use optional precomputed tensors.
src/transformers/models/qwen3_vl/processing_qwen3_vl.py Processor can optionally return extra precomputed vision tensors (incl. interleaved rotary IDs).
src/transformers/models/qwen3_vl/modular_qwen3_vl.py Vision path refactor to accept precomputed tensors; uses shared vision utils for pos-embed and rotary IDs.
src/transformers/models/qwen3_vl/modeling_qwen3_vl.py Generated modeling updated similarly (precomputed tensors + shared utils).
src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py Same precomputed-tensor refactor for MoE variant.
src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py Moves audio precompute helpers into the modular file; updates audio forward to accept precomputes.
src/transformers/models/qwen3_5/modular_qwen3_5.py Vision refactor to accept optional precomputed tensors; uses shared vision utils.
src/transformers/models/qwen3_5/modeling_qwen3_5.py Generated modeling updated similarly.
src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py Same precomputed-tensor refactor for MoE variant.
src/transformers/models/qwen2_vl/processing_qwen2_vl.py Processor can optionally return extra precomputed vision tensors.
src/transformers/models/qwen2_vl/modeling_qwen2_vl.py Vision forward accepts optional precomputed tensors; uses shared vision utils.
src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py Processor can optionally return extra precomputed vision tensors.
src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py Vision forward refactor: optional precomputes + shared get_window_index/rotary IDs/cu_seqlens.
src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py Generated modeling updated similarly.
src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py Moves audio precompute helpers into the modular file; updates audio forward to accept precomputes.
src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py Vision encoder refactor to accept grid_thw + optional precomputed rotary IDs / cu_seqlens.
src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py Generated modeling updated similarly.
src/transformers/models/glm4v/processing_glm4v.py Processor can optionally return extra precomputed vision tensors.
src/transformers/models/glm4v/modular_glm4v.py Vision forward accepts optional precomputed tensors; uses shared vision utils; minor tensor construction refactors.
src/transformers/models/glm4v/modeling_glm4v.py Generated modeling updated similarly.
src/transformers/models/glm4v_moe/modeling_glm4v_moe.py Same refactor for MoE variant.
src/transformers/models/glm46v/processing_glm46v.py Processor can optionally return extra precomputed vision tensors.
src/transformers/models/glm46v/modeling_glm46v.py Updates get_{image,video}_features signatures to accept precomputed tensors.
src/transformers/models/glm_ocr/modular_glm_ocr.py Vision forward accepts optional precomputed tensors; uses shared vision utils.
src/transformers/models/glm_ocr/modeling_glm_ocr.py Generated modeling updated similarly.
src/transformers/models/glm_image/modular_glm_image.py Vision forward accepts optional precomputed tensors; uses shared vision utils.
src/transformers/models/glm_image/modeling_glm_image.py Generated modeling updated similarly.
src/transformers/models/esm/configuration_esm.py Docstring reorders rope_theta description (docs-only change).
src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py Vision forward accepts optional precomputed tensors; uses shared vision utils.
src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py Generated modeling updated similarly.
docs/source/en/model_doc/nomic_bert.md Updates the paper link URL (docs-only change).

Comment thread src/transformers/vision_utils.py Outdated
Comment thread src/transformers/vision_utils.py Outdated
Comment thread src/transformers/models/qwen2_vl/modeling_qwen2_vl.py Outdated
Comment thread src/transformers/models/qwen3_vl/modeling_qwen3_vl.py Outdated
Comment thread src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py Outdated
Comment thread src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py Outdated
Comment on lines +773 to +776
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
video_cu_seqlens (`torch.IntTensor`, *optional*):
Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`).
Comment on lines +800 to +803
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
image_cu_seqlens (`torch.IntTensor`, *optional*):
Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`).
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
@IlyasMoutawwakil (Member Author)

run-slow: ernie4_5_vl_moe, esm, glm46v, glm4v, glm4v_moe, glm_image, glm_ocr, paddleocr_vl, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe

@github-actions (Contributor)

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/ernie4_5_vl_moe", "models/esm", "models/glm46v", "models/glm4v", "models/glm4v_moe", "models/glm_image", "models/glm_ocr", "models/paddleocr_vl", "models/qwen2_5_omni", "models/qwen2_5_vl", "models/qwen2_vl", "models/qwen3_5", "models/qwen3_5_moe", "models/qwen3_omni_moe", "models/qwen3_vl", "models/qwen3_vl_moe"]
quantizations: []

@zucchini-nlp (Member) left a comment

just a few quick comments, randomly chose one model file to review

Comment thread src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py Outdated
Comment on lines 1609 to 1617
self,
hidden_states: torch.Tensor,
grid_thw: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
window_index: torch.Tensor | None = None,
cu_window_seqlens: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple | BaseModelOutputWithPooling:
Member

tbh cu_seqlens and position_ids are already in TransformersKwargs, no?

Member Author

ah yes, true at this point. Should I remove them from the forward and pop them from kwargs?

Member Author

hmm, however, the cu_seqlens in TransformersKwargs are these two:

cu_seq_lens_q: torch.LongTensor | None
cu_seq_lens_k: torch.LongTensor | None

Comment on lines +1766 to +1767
video_cu_seqlens: torch.Tensor | None = None,
video_position_ids: torch.Tensor | None = None,
Member

and for utilities, imo we don't yet need these args because they aren't returned by the processor. Unless it is a requirement for export.

Member Author

if we remove them from here they can't be propagated through the model from forward. I can revert them, but that means the model will still not be compilable end-to-end; only its visual encoder will be.

Member

Ah, so we are currently consuming it all in kwargs via XXModel and explicitly writing it out in vision-related models?

I think cu-seq-lens and positions are kinda okay to be consumed with the existing kwargs, because they are assumed to represent FA-related arguments. Then we also have "cu_window_seqlens" indices in some models, which is actually the same cu-len used in FA.

So imo we can consolidate these two, maybe similar to how the attention mask is built? For ex, the model expects a dict of cu_seq_len with keys for layer types (full attn or window attn), roughly as sketched below.

also cc @vasqu, wdyt?
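
Purely for illustration, the kind of consolidated structure being suggested might look something like this (key names, shapes, and values are invented for the example; this is not an existing transformers API):

import torch

# One dict of cumulative sequence lengths keyed by layer type, analogous to how
# per-layer-type attention masks are prepared.
cu_seqlens_by_layer_type = {
    "full_attention": torch.tensor([0, 64, 128], dtype=torch.int32),
    "sliding_attention": torch.tensor([0, 16, 32, 48, 64, 80, 96, 112, 128], dtype=torch.int32),
}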

Contributor

Could this be solved through a decorator, maybe, that filters/maps the kwargs based on modality? I still don't like the args being super visible, because it should remain a power feature, but I also see that it is needed for proper propagation. Wdyt about the decorator solution then?

Member

Actually I don't mind the TypedDict unpacking, as imo these fit perfectly in "typical FA kwargs". Also, it doesn't look appealing to me when we have 4+ new args. What do you have in mind re decorators?

Contributor

Something along the lines of @fa_kwargs(modality="vision"), where we make an internal mapping that goes through the kwargs and maps them to the correct naming. This decorator wouldn't apply here, but where we actually need it, e.g. down the line in the vision attention. At least that's the rough idea.

Member

From my understanding, the naming within the vision model has to be without any prefixes. For ex, rn it accepts images and videos under the same arg name (pixel_values).

But yeah, we might need to prefix it in a general VLM forward in subsequent PRs, if we want to allow users to prepare FA kwargs and pass them down the line.

Comment thread src/transformers/vision_utils.py
@github-actions (Contributor)

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 7d5206ca workflow commit (merge commit)
PR 5dcc3ea8 branch commit (from PR)
main 8fb7c7e5 base commit (on main)

Model CI Report

5 new failed tests from this PR 😭

  • qwen3_omni_moe:
    tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py::Qwen3OmniModelIntegrationTest::test_small_model_integration_test_batch (❌ ⟹ ❌)

  • qwen3_vl_moe:
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test (❌ ⟹ ❌)
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test_batch_different_resolutions (❌ ⟹ ❌)
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test_batch_wo_image (❌ ⟹ ❌)
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test_expand (✅ ⟹ ❌)

@IlyasMoutawwakil (Member Author)

run-slow: ernie4_5_vl_moe, glm46v, glm4v, glm4v_moe, glm_image, glm_ocr, paddleocr_vl, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe, video_llama_3

@github-actions (Contributor)

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/ernie4_5_vl_moe", "models/glm46v", "models/glm4v", "models/glm4v_moe", "models/glm_image", "models/glm_ocr", "models/paddleocr_vl", "models/qwen2_5_omni", "models/qwen2_5_vl", "models/qwen2_vl", "models/qwen3_5", "models/qwen3_5_moe", "models/qwen3_omni_moe", "models/qwen3_vl", "models/qwen3_vl_moe", "models/video_llama_3"]
quantizations: []

@github-actions (Contributor)

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 4c90da64 workflow commit (merge commit)
PR 4246a72d branch commit (from PR)
main ca72aa0a base commit (on main)

Model CI Report

9 new failed tests from this PR 😭

  • qwen3_omni_moe:
    tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py::Qwen3OmniModelIntegrationTest::test_small_model_integration_test_multiturn (✅ ⟹ ❌)

  • qwen3_vl_moe:
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test (❌ ⟹ ❌)
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test_batch_different_resolutions (❌ ⟹ ❌)
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test_batch_wo_image (❌ ⟹ ❌)
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test_expand (✅ ⟹ ❌)

  • video_llama_3:
    tests/models/video_llama_3/test_modeling_video_llama_3.py::VideoLlama3IntegrationTest::test_small_model_integration_test (❌ ⟹ ❌)
    tests/models/video_llama_3/test_modeling_video_llama_3.py::VideoLlama3IntegrationTest::test_small_model_integration_test_batch (❌ ⟹ ❌)
    tests/models/video_llama_3/test_modeling_video_llama_3.py::VideoLlama3IntegrationTest::test_small_model_integration_test_batch_different_resolutions (❌ ⟹ ❌)
    tests/models/video_llama_3/test_modeling_video_llama_3.py::VideoLlama3IntegrationTest::test_small_model_integration_test_batch_wo_image (❌ ⟹ ❌)

@IlyasMoutawwakil (Member Author)

run-slow: qwen3_omni_moe, qwen3_vl_moe, video_llama_3

@github-actions (Contributor)

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/qwen3_omni_moe", "models/qwen3_vl_moe", "models/video_llama_3"]
quantizations: []

@github-actions (Contributor)

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN e30b65b7 workflow commit (merge commit)
PR 62d901c6 branch commit (from PR)
main 74d576be base commit (on main)

Model CI Report

3 new failed tests from this PR 😭

  • qwen3_vl_moe:
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test_batch_different_resolutions (❌ ⟹ ❌)

  • video_llama_3:
    tests/models/video_llama_3/test_modeling_video_llama_3.py::VideoLlama3IntegrationTest::test_small_model_integration_test_batch_different_resolutions (❌ ⟹ ❌)
    tests/models/video_llama_3/test_modeling_video_llama_3.py::VideoLlama3IntegrationTest::test_small_model_integration_test_batch_wo_image (❌ ⟹ ❌)

@IlyasMoutawwakil (Member Author)

run-slow: qwen3_vl_moe, video_llama_3

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: ernie4_5_vl_moe, glm46v, glm4v, glm4v_moe, glm_image, glm_ocr, paddleocr_vl, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe, video_llama_3
