From 05d2d21fd6cffd06c79c38b223a8ac58aed22021 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 10:44:04 +0200 Subject: [PATCH 01/56] Extract pure vision/audio functions into standalone utilities - Create top-level `modeling_vision_utils.py` with shared pure functions: `get_vision_cu_seqlens`, `get_rotary_pos_ids`, `get_rotary_pos_ids_interleaved`, `get_window_index`, `get_pos_embed_indices` - Move audio precompute functions (`chunk_and_pad_features`, `get_audio_cu_seqlens`, `get_valid_indices`, `get_pool_indices`) into modular files directly - Simplify `VisionRotaryEmbedding.forward` to accept `pos_ids` tensor directly via broadcast multiply, eliminating data-dependent table creation - Make vision/audio encoder forwards accept optional precomputed tensors (`cu_seqlens`, `rotary_pos_ids`, `window_index`, `embed_indices`, etc.) - Use explicit naming: `get_vision_cu_seqlens` / `get_audio_cu_seqlens` Models: qwen2_vl, qwen2_5_vl, qwen3_vl, qwen3_5, qwen3_vl_moe, qwen3_5_moe, qwen2_5_omni, qwen3_omni_moe, glm4v, glm4v_moe, glm_image, glm_ocr, ernie4_5_vl_moe, video_llama_3, mlcd, paddleocr_vl Co-Authored-By: Claude Opus 4.6 (1M context) --- src/transformers/modeling_vision_utils.py | 235 +++++++++++ .../modeling_ernie4_5_vl_moe.py | 64 +-- .../modular_ernie4_5_vl_moe.py | 26 +- .../models/glm4v/modeling_glm4v.py | 96 ++--- .../models/glm4v/modular_glm4v.py | 90 ++--- .../models/glm4v_moe/modeling_glm4v_moe.py | 96 ++--- .../models/glm_image/modeling_glm_image.py | 68 ++-- .../models/glm_image/modular_glm_image.py | 57 +-- .../models/glm_ocr/modeling_glm_ocr.py | 83 ++-- .../models/glm_ocr/modular_glm_ocr.py | 36 +- src/transformers/models/mlcd/modeling_mlcd.py | 138 +++---- src/transformers/models/mlcd/modular_mlcd.py | 110 ++---- .../paddleocr_vl/modeling_paddleocr_vl.py | 10 +- .../paddleocr_vl/modular_paddleocr_vl.py | 6 +- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 350 ++++++++--------- .../qwen2_5_omni/modular_qwen2_5_omni.py | 273 ++++++++----- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 137 ++----- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 131 ++----- .../models/qwen2_vl/modeling_qwen2_vl.py | 61 +-- .../models/qwen3_5/modeling_qwen3_5.py | 156 ++------ .../models/qwen3_5/modular_qwen3_5.py | 45 ++- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 157 ++------ .../models/qwen3_5_moe/modular_qwen3_5_moe.py | 2 + .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 366 ++++++++---------- .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 180 +++++++-- .../models/qwen3_vl/modeling_qwen3_vl.py | 153 ++------ .../models/qwen3_vl/modular_qwen3_vl.py | 147 ++----- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 152 ++------ .../video_llama_3/modeling_video_llama_3.py | 62 +-- .../video_llama_3/modular_video_llama_3.py | 61 +-- 30 files changed, 1521 insertions(+), 2027 deletions(-) create mode 100644 src/transformers/modeling_vision_utils.py diff --git a/src/transformers/modeling_vision_utils.py b/src/transformers/modeling_vision_utils.py new file mode 100644 index 000000000000..1fc7040f7553 --- /dev/null +++ b/src/transformers/modeling_vision_utils.py @@ -0,0 +1,235 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pure vision utility functions for computing data-dependent tensors. + +All functions are standalone (no model weights) and compute tensors from +``grid_thw`` + config scalars. They are used by vision encoders and can be +precomputed before ``torch.export`` tracing since they use untraceable ops +(``repeat_interleave``, ``.tolist()``, ``nonzero()``, loops). +""" + +import torch +import torch.nn.functional as F + + +def get_vision_cu_seqlens(grid_thw: torch.Tensor) -> torch.Tensor: + """Compute cumulative sequence lengths from vision grid info. + + Args: + grid_thw: ``(num_images_or_videos, 3)`` — temporal, height, width per entry. + + Returns: + ``cu_seqlens``: ``(total_patches + 1,)`` int32 cumulative sequence boundaries. + """ + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32 + ) + return F.pad(cu_seqlens, (1, 0), value=0) + + +def get_rotary_pos_ids(grid_thw: torch.Tensor, spatial_merge_size: int | torch.Tensor) -> torch.Tensor: + """Compute (row, col) position IDs for vision rotary embeddings. + + Args: + grid_thw: ``(num_images_or_videos, 3)`` + spatial_merge_size: merge block size — either a single ``int`` (same for all images) + or a ``(num_images_or_videos,)`` tensor (per-image). + + Returns: + ``pos_ids``: ``(total_tokens, 2)`` long — (row, col) position per token. + """ + device = grid_thw.device + if not isinstance(spatial_merge_size, torch.Tensor): + spatial_merge_size = torch.tensor([spatial_merge_size], device=device).expand(len(grid_thw)) + + pos_ids = [] + for (t, h, w), m in zip(grid_thw.tolist(), spatial_merge_size.tolist()): + t, h, w, m = int(t), int(h), int(w), int(m) + hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten() + + wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + return torch.cat(pos_ids, dim=0) + + +def get_rotary_pos_ids_interleaved(grid_thw: torch.Tensor, spatial_merge_size: int) -> torch.Tensor: + """Compute (row, col) position IDs for Qwen3-VL style vision rotary embeddings. + + Uses block-interleaved positions with intra-block offsets (different from the + Qwen2-VL variant which permutes whole rows/columns). + + Args: + grid_thw: ``(num_images_or_videos, 3)`` + spatial_merge_size: merge block size from vision config. + + Returns: + ``pos_ids``: ``(total_tokens, 2)`` long — (row, col) position per token. 
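+
+    Example (illustrative): for ``grid_thw = [[1, 4, 4]]`` and ``spatial_merge_size = 2``,
+    tokens are ordered one 2x2 merge block at a time: ``(0, 0), (0, 1), (1, 0), (1, 1)``
+    for the top-left block, then ``(0, 2), (0, 3), (1, 2), (1, 3)``, and so on.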
+ """ + m = spatial_merge_size + device = grid_thw.device + total_tokens = int((grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw.tolist(): + num_frames, height, width = int(num_frames), int(height), int(width) + merged_h, merged_w = height // m, width // m + + block_rows = torch.arange(merged_h, device=device) + block_cols = torch.arange(merged_w, device=device) + intra_row = torch.arange(m, device=device) + intra_col = torch.arange(m, device=device) + + row_idx = ( + (block_rows[:, None, None, None] * m + intra_row[None, None, :, None]) + .expand(merged_h, merged_w, m, m) + .reshape(-1) + ) + col_idx = ( + (block_cols[None, :, None, None] * m + intra_col[None, None, None, :]) + .expand(merged_h, merged_w, m, m) + .reshape(-1) + ) + + coords = torch.stack((row_idx, col_idx), dim=-1) + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + return pos_ids + + +def get_window_index( + grid_thw: torch.Tensor, + spatial_merge_size: int, + window_size: int, + patch_size: int, + spatial_merge_unit: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute window attention indices for vision encoders with windowed attention. + + Args: + grid_thw: ``(num_images_or_videos, 3)`` + spatial_merge_size: merge block size from vision config. + window_size: window size from vision config. + patch_size: patch size from vision config. + spatial_merge_unit: ``spatial_merge_size ** 2``. + + Returns: + ``window_index``: ``(total_tokens,)`` long — reorder indices for windowed attention. + ``cu_window_seqlens``: ``(num_windows + 1,)`` int32 — cumulative window boundaries. 
+ """ + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = window_size // spatial_merge_size // patch_size + + for grid_t, grid_h, grid_w in grid_thw.tolist(): + grid_t, grid_h, grid_w = int(grid_t), int(grid_h), int(grid_w) + llm_grid_h = grid_h // spatial_merge_size + llm_grid_w = grid_w // spatial_merge_size + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, num_windows_h, vit_merger_window_size, num_windows_w, vit_merger_window_size + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, num_windows_h * num_windows_w, vit_merger_window_size, vit_merger_window_size + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += grid_t * llm_grid_h * llm_grid_w + + window_index = torch.cat(window_index, dim=0) + cu_window_seqlens = torch.tensor(cu_window_seqlens, device=grid_thw.device, dtype=torch.int32) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + return window_index, cu_window_seqlens + + +def get_pos_embed_indices( + grid_thw: torch.Tensor, num_grid_per_side: int, spatial_merge_size: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute bilinear interpolation indices and weights for position embeddings. + + Args: + grid_thw: ``(num_images_or_videos, 3)`` + num_grid_per_side: ``int(num_position_embeddings ** 0.5)`` from vision config. + spatial_merge_size: merge block size from vision config. + + Returns: + ``embed_indices``: ``(4, total_thw)`` long — bilinear corner indices into pos_embed table. + ``bilinear_weights``: ``(4, total_thw)`` float — interpolation weights. 
+ """ + N = num_grid_per_side + m = spatial_merge_size + device = grid_thw.device + + idx_parts: list[list[torch.Tensor]] = [[] for _ in range(4)] + weight_parts: list[list[torch.Tensor]] = [[] for _ in range(4)] + + for t, h, w in grid_thw.tolist(): + t, h, w = int(t), int(h), int(w) + + h_idxs = torch.linspace(0, N - 1, h, device=device) + w_idxs = torch.linspace(0, N - 1, w, device=device) + + h_floor = h_idxs.int() + w_floor = w_idxs.int() + h_ceil = (h_floor + 1).clamp(max=N - 1) + w_ceil = (w_floor + 1).clamp(max=N - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + bh_f = h_floor * N + bh_c = h_ceil * N + + raw_idx = [ + (bh_f[:, None] + w_floor[None, :]).flatten(), + (bh_f[:, None] + w_ceil[None, :]).flatten(), + (bh_c[:, None] + w_floor[None, :]).flatten(), + (bh_c[:, None] + w_ceil[None, :]).flatten(), + ] + raw_w = [ + ((1 - dh)[:, None] * (1 - dw)[None, :]).flatten(), + ((1 - dh)[:, None] * dw[None, :]).flatten(), + (dh[:, None] * (1 - dw)[None, :]).flatten(), + (dh[:, None] * dw[None, :]).flatten(), + ] + + h_idx = torch.arange(h, device=device).view(h // m, m) + w_idx = torch.arange(w, device=device).view(w // m, m) + reorder = (h_idx[:, :, None, None] * w + w_idx[None, None, :, :]).permute(0, 2, 1, 3).flatten().repeat(t) + + for i in range(4): + idx_parts[i].append(raw_idx[i][reorder]) + weight_parts[i].append(raw_w[i][reorder]) + + embed_indices = torch.stack([torch.cat(p) for p in idx_parts]) + bilinear_weights = torch.stack([torch.cat(p) for p in weight_parts]) + return embed_indices, bilinear_weights diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index d66e9fdc5dc7..c5580476b78d 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutputWithPooling, MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -855,10 +856,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) @auto_docstring @@ -893,58 +892,35 @@ def __init__(self, config) -> None: self.post_init() - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // 
self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - @merge_with_config_defaults @capture_outputs def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. + cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). """ hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) for block in self.blocks: hidden_states = block( diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index 42bbb44b70a5..985518cfedc5 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -20,7 +20,6 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from huggingface_hub.dataclasses import strict from torchvision.transforms.v2 import functional as tvF @@ -43,6 +42,7 @@ from ...modeling_outputs import BaseModelOutputWithPooling, MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import ImagesKwargs, Unpack from ...utils import ( TensorType, @@ -699,22 +699,24 @@ def get_device(self): @merge_with_config_defaults @capture_outputs def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + 
rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) for block in self.blocks: hidden_states = block( diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 6189d0f547ef..57e3481605e2 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -38,6 +38,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -110,10 +111,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Glm4vVisionPatchMerger(nn.Module): @@ -180,12 +179,11 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc ) # Calculate target dimensions for each patch - target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) - target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) + num_tokens = embeddings.shape[0] + token_positions = torch.arange(num_tokens, device=embeddings.device) + seg_ids = (token_positions.unsqueeze(0) >= lengths.cumsum(0).unsqueeze(1)).sum(0) + target_h = image_shapes[seg_ids, 1].to(dtype=torch.float32) + target_w = image_shapes[seg_ids, 2].to(dtype=torch.float32) # Normalize coordinates to [-1, 1] range for grid_sample norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 @@ -727,72 +725,50 @@ def __init__(self, config) -> None: self.gradient_checkpointing = False self.post_init() - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: 
- hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb, pos_ids - @merge_with_config_defaults @capture_outputs @auto_docstring def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). Returns: `torch.Tensor`: hidden_states. 
""" hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + + rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, seqlens, grid_thw, - image_type_ids[:, 0].to(hidden_states.device), - image_type_ids[:, 1].to(hidden_states.device), + rotary_pos_ids[:, 0].to(hidden_states.device), + rotary_pos_ids[:, 1].to(hidden_states.device), ) for blk in self.blocks: @@ -1135,12 +1111,12 @@ def get_video_features( """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames - temp_frames_hw = [] - video_grid_thw_list = video_grid_thw.tolist() - for t, h, w in video_grid_thw_list: - repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) - temp_frames_hw.append(repeated_row) - flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + t = video_grid_thw[:, 0] + hw = video_grid_thw[:, 1:] + # repeat each (h,w) row `t` times + flattened_hw = torch.repeat_interleave(hw, t, dim=0) + prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) + flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs ) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 1ffd06532a8b..1b70144ccf64 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -33,6 +33,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_rope_utils import RopeParameters from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import ( @@ -315,12 +316,11 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc ) # Calculate target dimensions for each patch - target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) - target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) + num_tokens = embeddings.shape[0] + token_positions = torch.arange(num_tokens, device=embeddings.device) + seg_ids = 
(token_positions.unsqueeze(0) >= lengths.cumsum(0).unsqueeze(1)).sum(0) + target_h = image_shapes[seg_ids, 1].to(dtype=torch.float32) + target_w = image_shapes[seg_ids, 2].to(dtype=torch.float32) # Normalize coordinates to [-1, 1] range for grid_sample norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 @@ -610,72 +610,50 @@ def __init__(self, config) -> None: self.gradient_checkpointing = False self.post_init() - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb, pos_ids - @merge_with_config_defaults @capture_outputs @auto_docstring def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). Returns: `torch.Tensor`: hidden_states. 
""" hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + + rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, seqlens, grid_thw, - image_type_ids[:, 0].to(hidden_states.device), - image_type_ids[:, 1].to(hidden_states.device), + rotary_pos_ids[:, 0].to(hidden_states.device), + rotary_pos_ids[:, 1].to(hidden_states.device), ) for blk in self.blocks: @@ -821,12 +799,12 @@ def get_video_features( """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames - temp_frames_hw = [] - video_grid_thw_list = video_grid_thw.tolist() - for t, h, w in video_grid_thw_list: - repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) - temp_frames_hw.append(repeated_row) - flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + t = video_grid_thw[:, 0] + hw = video_grid_thw[:, 1:] + # repeat each (h,w) row `t` times + flattened_hw = torch.repeat_interleave(hw, t, dim=0) + prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) + flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs ) diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 363e4269f3a6..fa863b6e55ab 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -38,6 +38,7 @@ from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, torch_compilable_check from ...utils.generic import can_return_tuple, is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -469,10 +470,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = 
torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) @use_kernel_forward_from_hub("RMSNorm") @@ -594,12 +593,11 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc ) # Calculate target dimensions for each patch - target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) - target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) + num_tokens = embeddings.shape[0] + token_positions = torch.arange(num_tokens, device=embeddings.device) + seg_ids = (token_positions.unsqueeze(0) >= lengths.cumsum(0).unsqueeze(1)).sum(0) + target_h = image_shapes[seg_ids, 1].to(dtype=torch.float32) + target_w = image_shapes[seg_ids, 2].to(dtype=torch.float32) # Normalize coordinates to [-1, 1] range for grid_sample norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 @@ -791,72 +789,50 @@ def __init__(self, config) -> None: self.gradient_checkpointing = False self.post_init() - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb, pos_ids - @merge_with_config_defaults @capture_outputs @auto_docstring def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). Returns: `torch.Tensor`: hidden_states. 
""" hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + + rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, seqlens, grid_thw, - image_type_ids[:, 0].to(hidden_states.device), - image_type_ids[:, 1].to(hidden_states.device), + rotary_pos_ids[:, 0].to(hidden_states.device), + rotary_pos_ids[:, 1].to(hidden_states.device), ) for blk in self.blocks: @@ -1304,12 +1280,12 @@ def get_video_features( """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames - temp_frames_hw = [] - video_grid_thw_list = video_grid_thw.tolist() - for t, h, w in video_grid_thw_list: - repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) - temp_frames_hw.append(repeated_row) - flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + t = video_grid_thw[:, 0] + hw = video_grid_thw[:, 1:] + # repeat each (h,w) row `t` times + flattened_hw = torch.repeat_interleave(hw, t, dim=0) + prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) + flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs ) diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 7215076ea8cf..b3d306e31a67 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -36,6 +36,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import maybe_autocast, merge_with_config_defaults @@ -236,12 +237,11 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc ) # Calculate target dimensions for each patch - target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) - target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) + 
num_tokens = embeddings.shape[0] + token_positions = torch.arange(num_tokens, device=embeddings.device) + seg_ids = (token_positions.unsqueeze(0) >= lengths.cumsum(0).unsqueeze(1)).sum(0) + target_h = image_shapes[seg_ids, 1].to(dtype=torch.float32) + target_w = image_shapes[seg_ids, 2].to(dtype=torch.float32) # Normalize coordinates to [-1, 1] range for grid_sample norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 @@ -589,63 +589,45 @@ def __init__(self, config: GlmImageVisionConfig) -> None: self.head_dim = head_dim self.post_init() - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - return pos_ids - @merge_with_config_defaults @capture_outputs @auto_docstring def forward( - self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`): Packed pixel values. grid_thw (`torch.Tensor` of shape `(num_images, 3)`): The temporal, height and width of feature shape of each image. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). Returns: `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. 
""" - hidden_states = self.patch_embed(pixel_values) - image_type_ids = self.rot_pos_emb(grid_thw) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, seqlens, grid_thw, - image_type_ids[:, 0].to(hidden_states.device), - image_type_ids[:, 1].to(hidden_states.device), + rotary_pos_ids[:, 0].to(hidden_states.device), + rotary_pos_ids[:, 1].to(hidden_states.device), ) # Transformer blocks (no position_embeddings needed, already added above) diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index e72aede3da66..ba572d1485a8 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -28,6 +28,7 @@ from ...image_utils import ImageInput from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import ImagesKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging @@ -428,63 +429,45 @@ def __init__(self, config: GlmImageVisionConfig): del self.downsample del self.post_layernorm - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - return pos_ids - @merge_with_config_defaults @capture_outputs @auto_docstring def forward( - self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`): Packed pixel values. grid_thw (`torch.Tensor` of shape `(num_images, 3)`): The temporal, height and width of feature shape of each image. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). 
Returns: `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. """ - hidden_states = self.patch_embed(pixel_values) - image_type_ids = self.rot_pos_emb(grid_thw) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, seqlens, grid_thw, - image_type_ids[:, 0].to(hidden_states.device), - image_type_ids[:, 1].to(hidden_states.device), + rotary_pos_ids[:, 0].to(hidden_states.device), + rotary_pos_ids[:, 1].to(hidden_states.device), ) # Transformer blocks (no position_embeddings needed, already added above) diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index 30703d81c8c1..6515d5e21e2b 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -25,7 +25,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from torch.nn import LayerNorm from ... import initialization as init @@ -39,6 +38,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -312,10 +312,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) @auto_docstring @@ -579,62 +577,41 @@ def __init__(self, config) -> None: self.gradient_checkpointing = False self.post_init() - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - 
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb, pos_ids - @merge_with_config_defaults @capture_outputs @auto_docstring - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: r""" hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + + rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + emb = torch.cat((rotary_emb, rotary_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) for blk in self.blocks: hidden_states = blk( @@ -1051,12 +1028,12 @@ def get_video_features( """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames - temp_frames_hw = [] - video_grid_thw_list = video_grid_thw.tolist() - for t, h, w in video_grid_thw_list: - repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) - temp_frames_hw.append(repeated_row) - flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + t = video_grid_thw[:, 0] + hw = video_grid_thw[:, 1:] + # repeat each (h,w) row `t` times + flattened_hw = torch.repeat_interleave(hw, t, dim=0) + prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) + flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs ) diff --git a/src/transformers/models/glm_ocr/modular_glm_ocr.py b/src/transformers/models/glm_ocr/modular_glm_ocr.py index 2f71dded711d..5fc29c0c131a 100644 --- a/src/transformers/models/glm_ocr/modular_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modular_glm_ocr.py @@ -16,11 +16,11 @@ import torch import torch.nn as nn -import torch.nn.functional as F from huggingface_hub.dataclasses import strict from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...utils import auto_docstring 
from ..glm4v.configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig from ..glm4v.modeling_glm4v import ( @@ -247,30 +247,38 @@ def __init__(self, config) -> None: hidden_act=config.hidden_act, ) - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: r""" hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + + rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + emb = torch.cat((rotary_emb, rotary_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) for blk in self.blocks: hidden_states = blk( diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index efc0bb807d2d..a9e6ff9a930a 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -59,37 +59,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor: - """ - Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size. - - Args: - num_patches_height (int): Number of patches in the height dimension. - num_patches_width (int): Number of patches in the width dimension. - - Returns: - torch.Tensor: Rotary positional embeddings for the given grid size. 
- """ - # Generate position IDs for height and width dimensions - hpos_ids = ( - torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width) - ) - wpos_ids = ( - torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1) - ) - - # Flatten and stack the position IDs - pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1) - - # Generate the full rotary positional embeddings for the maximum grid size - max_grid_size = max(num_patches_height, num_patches_width) - seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - rotary_pos_emb_full = torch.outer(seq, self.inv_freq) - - # Select and flatten the embeddings based on the position IDs - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - - return rotary_pos_emb + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class MLCDVisionEmbeddings(nn.Module): @@ -399,7 +370,8 @@ def forward( @auto_docstring class MLCDPreTrainedModel(PreTrainedModel): config: MLCDVisionConfig - base_model_prefix = "mlcd" + base_model_prefix = "vision_model" + _no_split_modules = ["MLCDEncoderLayer"] supports_gradient_checkpointing = True accepts_loss_kwargs = False _supports_flash_attn = True @@ -434,7 +406,7 @@ def _init_weights(self, module): fc_std = (2 * module.config.hidden_size) ** -0.5 * factor init.normal_(module.fc1.weight, std=fc_std) init.normal_(module.fc2.weight, std=in_proj_std) - elif isinstance(module, MLCDVisionTransformer): + elif isinstance(module, MLCDVisionModel): factor = self.config.initializer_factor pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std) @@ -448,15 +420,19 @@ def _init_weights(self, module): init.copy_(module.inv_freq, inv_freq) -class MLCDVisionTransformer(MLCDPreTrainedModel): +@auto_docstring( + custom_intro=""" + The vision model from M_L_C_D without any head or projection on top. 
+ """ +) +class MLCDVisionModel(MLCDPreTrainedModel): config: MLCDVisionConfig main_input_name = "pixel_values" input_modalities = ("image",) - _no_split_modules = ["MLCDEncoderLayer"] + _input_embed_layer = "patch_embedding" def __init__(self, config: MLCDVisionConfig): super().__init__(config) - self.config = config embed_dim = config.hidden_size self.embeddings = MLCDVisionEmbeddings(config) @@ -469,62 +445,6 @@ def __init__(self, config: MLCDVisionConfig): @merge_with_config_defaults @capture_outputs(tie_last_hidden_states=False) - @auto_docstring - def forward( - self, - pixel_values: torch.FloatTensor | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple | BaseModelOutputWithPooling: - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - num_patches_height = pixel_values.shape[-2] // self.config.patch_size - num_patches_width = pixel_values.shape[-1] // self.config.patch_size - rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width) - rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device) - rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = (emb.cos(), emb.sin()) - - hidden_states = self.embeddings(pixel_values) - hidden_states = self.pre_layrnorm(hidden_states) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - position_embeddings=position_embeddings, - **kwargs, - ) - - last_hidden_state = encoder_outputs[0] - pooled_output = last_hidden_state[:, 0, :] - pooled_output = self.post_layernorm(pooled_output) - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - ) - - -@auto_docstring( - custom_intro=""" - The vision model from M_L_C_D without any head or projection on top. 
- """ -) -class MLCDVisionModel(MLCDPreTrainedModel): - config: MLCDVisionConfig - main_input_name = "pixel_values" - input_modalities = ("image",) - _no_split_modules = ["MLCDEncoderLayer"] - - def __init__(self, config: MLCDVisionConfig): - super().__init__(config) - self.vision_model = MLCDVisionTransformer(config) - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.vision_model.embeddings.patch_embedding - @auto_docstring def forward( self, @@ -555,10 +475,40 @@ def forward( >>> print(f"Number of attention layers: {len(outputs.attentions)}") >>> print(f"Attention shape: {outputs.attentions[0].shape}") ```""" - return self.vision_model( - pixel_values=pixel_values, + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + num_patches_height = pixel_values.shape[-2] // self.config.patch_size + num_patches_width = pixel_values.shape[-1] // self.config.patch_size + hpos_ids = ( + torch.arange(num_patches_height, device=pixel_values.device).unsqueeze(1).expand(-1, num_patches_width) + ) + wpos_ids = ( + torch.arange(num_patches_width, device=pixel_values.device).unsqueeze(0).expand(num_patches_height, -1) + ) + pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1) + rotary_pos_emb = self.vision_rotary_embedding(pos_ids) + rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + position_embeddings=position_embeddings, **kwargs, ) + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + __all__ = ["MLCDPreTrainedModel", "MLCDVisionModel"] diff --git a/src/transformers/models/mlcd/modular_mlcd.py b/src/transformers/models/mlcd/modular_mlcd.py index 01641c7d5ce1..a1c291211730 100644 --- a/src/transformers/models/mlcd/modular_mlcd.py +++ b/src/transformers/models/mlcd/modular_mlcd.py @@ -30,7 +30,6 @@ CLIPEncoderLayer, CLIPVisionEmbeddings, CLIPVisionModel, - CLIPVisionTransformer, ) from ..llama.modeling_llama import eager_attention_forward from ..qwen2_vl.modeling_qwen2_vl import VisionRotaryEmbedding, apply_rotary_pos_emb_vision @@ -84,37 +83,7 @@ class MLCDMLP(CLIPMLP): class MLCDRotaryEmbedding(VisionRotaryEmbedding): - def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor: - """ - Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size. - - Args: - num_patches_height (int): Number of patches in the height dimension. - num_patches_width (int): Number of patches in the width dimension. - - Returns: - torch.Tensor: Rotary positional embeddings for the given grid size. 
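The `pos_ids`-based rotary forward above replaces the old max-grid frequency table with a direct broadcast. A minimal sketch with toy dimensions (plain tensors rather than the module) checking that the two formulations agree:

```python
import torch

dim, theta = 8, 10000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

num_patches_height, num_patches_width = 3, 5
hpos = torch.arange(num_patches_height).unsqueeze(1).expand(-1, num_patches_width)
wpos = torch.arange(num_patches_width).unsqueeze(0).expand(num_patches_height, -1)
pos_ids = torch.stack([hpos.flatten(), wpos.flatten()], dim=-1)  # (H*W, 2)

# New path: broadcast multiply, no intermediate table.
new = (pos_ids.unsqueeze(-1) * inv_freq).flatten(1)

# Old path: frequency table for the max grid size, then a gather.
seq = torch.arange(max(num_patches_height, num_patches_width), dtype=inv_freq.dtype)
old = torch.outer(seq, inv_freq)[pos_ids].flatten(1)

torch.testing.assert_close(new, old)
print(new.shape)  # torch.Size([15, 8])
```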
- """ - # Generate position IDs for height and width dimensions - hpos_ids = ( - torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width) - ) - wpos_ids = ( - torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1) - ) - - # Flatten and stack the position IDs - pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1) - - # Generate the full rotary positional embeddings for the maximum grid size - max_grid_size = max(num_patches_height, num_patches_width) - seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - rotary_pos_emb_full = torch.outer(seq, self.inv_freq) - - # Select and flatten the embeddings based on the position IDs - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - - return rotary_pos_emb + pass class MLCDVisionEmbeddings(CLIPVisionEmbeddings): @@ -289,7 +258,8 @@ def forward( @auto_docstring class MLCDPreTrainedModel(PreTrainedModel): config: MLCDVisionConfig - base_model_prefix = "mlcd" + base_model_prefix = "vision_model" + _no_split_modules = ["MLCDEncoderLayer"] supports_gradient_checkpointing = True accepts_loss_kwargs = False _supports_flash_attn = True @@ -324,7 +294,7 @@ def _init_weights(self, module): fc_std = (2 * module.config.hidden_size) ** -0.5 * factor init.normal_(module.fc1.weight, std=fc_std) init.normal_(module.fc2.weight, std=in_proj_std) - elif isinstance(module, MLCDVisionTransformer): + elif isinstance(module, MLCDVisionModel): factor = self.config.initializer_factor pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std) @@ -338,48 +308,12 @@ def _init_weights(self, module): init.copy_(module.inv_freq, inv_freq) -class MLCDVisionTransformer(CLIPVisionTransformer): +class MLCDVisionModel(CLIPVisionModel): def __init__(self, config: MLCDVisionConfig): super().__init__(config) self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2) self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2)) - def forward( - self, - pixel_values: torch.FloatTensor | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> tuple | BaseModelOutputWithPooling: - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - num_patches_height = pixel_values.shape[-2] // self.config.patch_size - num_patches_width = pixel_values.shape[-1] // self.config.patch_size - rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width) - rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device) - rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = (emb.cos(), emb.sin()) - - hidden_states = self.embeddings(pixel_values) - hidden_states = self.pre_layrnorm(hidden_states) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - position_embeddings=position_embeddings, - **kwargs, - ) - - last_hidden_state = encoder_outputs[0] - pooled_output = last_hidden_state[:, 0, :] - pooled_output = self.post_layernorm(pooled_output) - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - ) - - -class MLCDVisionModel(CLIPVisionModel): def forward( self, pixel_values: torch.FloatTensor | None = None, @@ -409,11 +343,41 @@ def 
forward( >>> print(f"Number of attention layers: {len(outputs.attentions)}") >>> print(f"Attention shape: {outputs.attentions[0].shape}") ```""" - return self.vision_model( - pixel_values=pixel_values, + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + num_patches_height = pixel_values.shape[-2] // self.config.patch_size + num_patches_width = pixel_values.shape[-1] // self.config.patch_size + hpos_ids = ( + torch.arange(num_patches_height, device=pixel_values.device).unsqueeze(1).expand(-1, num_patches_width) + ) + wpos_ids = ( + torch.arange(num_patches_width, device=pixel_values.device).unsqueeze(0).expand(num_patches_height, -1) + ) + pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1) + rotary_pos_emb = self.vision_rotary_embedding(pos_ids) + rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + position_embeddings=position_embeddings, **kwargs, ) + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + ) + __all__ = [ "MLCDVisionConfig", diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 8ed3be0ad4be..9e3e83434958 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -98,10 +98,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class PaddleOCRRotaryEmbedding(nn.Module): @@ -857,9 +855,7 @@ def forward( height_position_ids = torch.concat(split_hids, dim=0) pids = torch.stack([height_position_ids, width_position_ids], dim=-1) - max_grid_size = pids.max() + 1 - rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size) - rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1) + rotary_embeddings = self.rotary_pos_emb(pids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 12d935978415..a3e1ec0dc4d4 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -533,7 +533,7 @@ class PaddleOCRVLConfig(Qwen2VLConfig): video_token_id: int = 100296 vision_start_token_id: int = 101305 vision_end_token_id: int = 101306 - tie_word_embeddings: int = True + tie_word_embeddings: bool = True class PaddleOCRProjector(nn.Module): @@ -823,9 +823,7 @@ def forward( height_position_ids = torch.concat(split_hids, dim=0) pids = 
torch.stack([height_position_ids, width_position_ids], dim=-1) - max_grid_size = pids.max() + 1 - rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size) - rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1) + rotary_embeddings = self.rotary_pos_emb(pids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index c8824b2f9730..2a517b346130 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -41,6 +41,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, @@ -716,6 +717,86 @@ def forward(self, seqlen: int): return self.positional_embedding[:seqlen, :] +def chunk_and_pad_features( + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Split audio features into fixed-size chunks and pad to uniform length. + + Each audio sample is split into chunks of ``n_window * 2`` frames (the last + chunk may be shorter), then all chunks are right-padded to the longest chunk. + + Args: + input_features: ``(feature_dim, total_frames)`` concatenated audio features. + feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window: half the target chunk size in frames. + + Returns: + ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. + ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. + """ + chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() + chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, n_window * 2, chunk_lengths) + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + return padded_feature, chunk_lengths + + +def get_audio_cu_seqlens(chunk_lengths: torch.Tensor) -> torch.Tensor: + """Compute cumulative sequence lengths for audio attention from chunk lengths. + + Applies one stride-2 convolution length reduction, then returns cumulative + boundaries for flash-attention-style sequence packing. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + + Returns: + ``(num_chunks + 1,)`` int32 cumulative sequence boundaries. + """ + after_conv1 = (chunk_lengths - 1) // 2 + 1 + return F.pad(after_conv1.cumsum(0), (1, 0), value=0).to(torch.int32) + + +def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after one stride-2 conv. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + + Returns: + ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_conv)`` grid. 
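To make the chunking arithmetic above concrete, a small worked example with made-up lengths (two audio samples, `n_window = 50`); only the counts matter here.

```python
import torch

n_window = 50                            # toy value; the real one comes from the audio config
feature_lens = torch.tensor([230, 100])  # made-up per-sample mel frame counts
chunk_size = n_window * 2                # 100 frames per chunk

chunk_num = torch.ceil(feature_lens / chunk_size).long()
print(chunk_num.tolist())                # [3, 1] chunks per sample

# chunk_and_pad_features yields chunk_lengths [100, 100, 30, 100]: two full chunks plus a
# 30-frame tail for the first sample, and a tail that is itself a full chunk for the second.
chunk_lengths = torch.tensor([100, 100, 30, 100])

# After the first stride-2 conv each chunk shrinks to (len - 1) // 2 + 1 frames; this is the
# grid that get_valid_indices and get_audio_cu_seqlens index into.
after_conv1 = (chunk_lengths - 1) // 2 + 1
print(after_conv1.tolist())              # [50, 50, 15, 50]
print(after_conv1.sum().item())          # 165 valid (non-padded) positions
```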
+ """ + after_conv1 = (chunk_lengths - 1) // 2 + 1 + max_len = after_conv1.max().item() + mask = torch.arange(max_len, device=chunk_lengths.device) < after_conv1.unsqueeze(1) + return mask.flatten().nonzero().squeeze(-1) + + +def get_pool_indices(feature_lens: torch.Tensor) -> torch.Tensor: + """Compute indices for stride-2 pooling over post-CNN audio features. + + Selects every other position (even indices) from each sample's post-CNN + features, accounting for two convolution stages and variable-length samples. + + Args: + feature_lens: ``(batch_size,)`` per-sample raw frame counts. + + Returns: + ``(total_pairs,)`` flat indices for stride-2 pooling across concatenated samples. + """ + after_conv1 = (feature_lens - 1) // 2 + 1 + after_conv2 = (after_conv1 - 2) // 2 + 1 + num_pairs = (after_conv2 - 1 + 1) // 2 + offsets = F.pad(after_conv2[:-1].cumsum(0), (1, 0), value=0) + pair_offsets = torch.repeat_interleave(offsets, num_pairs) + local_indices = torch.arange(num_pairs.sum(), device=feature_lens.device) + local_indices -= torch.repeat_interleave(F.pad(num_pairs[:-1].cumsum(0), (1, 0), value=0), num_pairs) + return pair_offsets + local_indices * 2 + + @auto_docstring( custom_intro=""" Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a @@ -765,60 +846,76 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value: nn.Module): self.conv1 = value - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if is_flash_attention_requested(self.config): - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - @merge_with_config_defaults @capture_outputs(tie_last_hidden_states=False) @auto_docstring - def forward(self, input_features, feature_lens=None, aftercnn_lens=None, **kwargs: Unpack[TransformersKwargs]): + def forward( + self, + input_features=None, + feature_lens=None, + padded_feature=None, + chunk_lengths=None, + valid_indices=None, + pool_indices=None, + cu_seqlens=None, + **kwargs: Unpack[TransformersKwargs], + ): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length - aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): - mel length after cnn + padded_feature (`torch.FloatTensor`, *optional*): + Precomputed padded audio chunks (from `chunk_and_pad_features`). + chunk_lengths (`torch.LongTensor`, *optional*): + Precomputed per-chunk lengths (from `chunk_and_pad_features`). + valid_indices (`torch.LongTensor`, *optional*): + Precomputed flat indices of valid post-CNN positions (from `get_valid_indices`). + pool_indices (`torch.LongTensor`, *optional*): + Precomputed pair indices for post-encoder average pooling (from `get_pool_indices`). + cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). 
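With the new optional inputs, a caller can run the untraceable preprocessing once and hand the encoder only plain tensors. This is an illustrative sketch, not runnable as-is: `audio_tower`, `input_features`, and `feature_lens` are placeholders, and the helpers are the ones added in this patch (the audio boundaries come from `get_audio_cu_seqlens`).

```python
# Hypothetical preprocessing step outside the encoder; `audio_tower`, `input_features`
# and `feature_lens` are placeholders for an instantiated audio encoder and its inputs.
padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, audio_tower.n_window)
valid_indices = get_valid_indices(chunk_lengths)
pool_indices = get_pool_indices(feature_lens)
cu_seqlens = get_audio_cu_seqlens(chunk_lengths)

audio_embeds = audio_tower(
    padded_feature=padded_feature,
    chunk_lengths=chunk_lengths,
    valid_indices=valid_indices,
    pool_indices=pool_indices,
    cu_seqlens=cu_seqlens,
).last_hidden_state
```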
""" - chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + if padded_feature is None: + padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) - chunk_lengths = torch.full((chunk_num.sum(),), self.n_window * 2, dtype=torch.long, device=feature_lens.device) - tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] - chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) - chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths) + if valid_indices is None: + valid_indices = get_valid_indices(chunk_lengths) - chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) - padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function( - chunk_list, chunk_lengths, padding_value=0, padding_side="right" + if pool_indices is None: + pool_indices = get_pool_indices(feature_lens) + + # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) + padded_mask = ( + (torch.arange(padded_feature.shape[2], device=padded_feature.device) < chunk_lengths.unsqueeze(1)) + .unsqueeze(1) + .long() ) padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2) - padded_embed = padded_embed + self.positional_embedding.positional_embedding[ : padded_embed.shape[1], : ].unsqueeze(0).to(padded_embed.dtype) - hidden_states = padded_embed[padded_mask_after_cnn] - cu_seqlens = torch.cat( - ( - torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32), - padded_mask_after_cnn.sum(1).cumsum(0), + hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) + + if cu_seqlens is None: + cu_seqlens = get_audio_cu_seqlens(chunk_lengths) + + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention mask only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + if is_flash_attention_requested(self.config): + attention_mask = None + else: + seq_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) + block_ids = torch.searchsorted(cu_seqlens[1:], seq_idx, right=True) + same_block = block_ids.unsqueeze(0) == block_ids.unsqueeze(1) + attention_mask = torch.full( + (hidden_states.shape[0], hidden_states.shape[0]), + torch.finfo(hidden_states.dtype).min, + dtype=hidden_states.dtype, + device=hidden_states.device, ) - ).to(torch.int32) - attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) + attention_mask = attention_mask.masked_fill(same_block, 0.0).unsqueeze(0).unsqueeze(0) for encoder_layer in self.layers: layer_outputs = encoder_layer( @@ -829,54 +926,12 @@ def forward(self, input_features, feature_lens=None, aftercnn_lens=None, **kwarg ) hidden_states = layer_outputs[0] - hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) - token_audio_list = [] - for each_audio_states in hidden_states_list: - each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1) - each_audio_states = self.ln_post(each_audio_states) - each_audio_states = self.proj(each_audio_states) - token_audio_list.append(each_audio_states) - token_audio = torch.cat(token_audio_list, dim=0) + # Post-process: average consecutive pairs per audio, then project + pooled = (hidden_states[pool_indices] + hidden_states[pool_indices + 1]) / 2 + pooled = self.ln_post(pooled) + token_audio = self.proj(pooled) return BaseModelOutputWithPooling(last_hidden_state=token_audio) - def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): - """ - Pads a sequence of tensors to their maximum length on indicated `padding_side`. - Then prepares a mask so that pad tokens are not attended to. 
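As a quick check that the `searchsorted`-based construction reproduces the removed per-block loop, here is a self-contained comparison on toy boundaries:

```python
import torch

cu_seqlens = torch.tensor([0, 3, 5, 9])  # toy block boundaries
seq_len = int(cu_seqlens[-1])
min_val = torch.finfo(torch.float32).min

# New path: assign each position a block id, allow attention only inside its own block.
block_ids = torch.searchsorted(cu_seqlens[1:], torch.arange(seq_len), right=True)
same_block = block_ids.unsqueeze(0) == block_ids.unsqueeze(1)
new_mask = torch.full((seq_len, seq_len), min_val).masked_fill(same_block, 0.0)

# Old path: explicit Python loop zeroing each diagonal block.
old_mask = torch.full((seq_len, seq_len), min_val)
for i in range(1, len(cu_seqlens)):
    old_mask[cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

torch.testing.assert_close(new_mask, old_mask)
```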
- """ - max_len = tensor_len.max() - dim = tensor_list[0].shape[0] - padded_tensor = torch.full( - size=(len(tensor_list), dim, max_len), - fill_value=padding_value, - dtype=self.dtype, - device=tensor_list[0].device, - ) - - batch_mask = torch.zeros( - (len(tensor_len), max_len), - dtype=torch.long, - device=padded_tensor.device, - ) - for i, length in enumerate(tensor_len): - batch_mask[i, :length] = 1 - padded_tensor[i, :, :length] = tensor_list[i] - - feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 - max_len_after_cnn = feature_lens_after_cnn.max() - batch_mask_after_cnn = torch.zeros( - (len(tensor_len), max_len_after_cnn), - dtype=torch.long, - device=padded_tensor.device, - ) - for i, length in enumerate(feature_lens_after_cnn): - batch_mask_after_cnn[i, :length] = 1 - return ( - padded_tensor, - batch_mask.unsqueeze(1), - batch_mask_after_cnn.bool(), - ) - # Ignore copy def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -1067,10 +1122,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen2_5_VisionPatchEmbed(nn.Module): @@ -1152,81 +1205,17 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> self.post_init() - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw.tolist(): - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size - grid_thw_list = grid_thw.tolist() - - for grid_t, grid_h, grid_w in grid_thw_list: - llm_grid_h, llm_grid_w = ( - grid_h // self.spatial_merge_size, - grid_w // self.spatial_merge_size, - ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) - index_padded = index_padded.reshape( - grid_t, - num_windows_h, - 
vit_merger_window_size, - num_windows_w, - vit_merger_window_size, - ) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, - num_windows_h * num_windows_w, - vit_merger_window_size, - vit_merger_window_size, - ) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += grid_t * llm_grid_h * llm_grid_w - window_index = torch.cat(window_index, dim=0) - - return window_index, cu_window_seqlens - @merge_with_config_defaults @capture_outputs def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + window_index: torch.Tensor | None = None, + cu_window_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: @@ -1234,22 +1223,35 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + cu_window_seqlens (`torch.Tensor`, *optional*): + Precomputed window cumulative sequence lengths (from `get_window_index`). + window_index (`torch.Tensor`, *optional*): + Precomputed window reordering index (from `get_window_index`). + rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). Returns: `torch.Tensor`: hidden_states. 
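In the same spirit, everything data-dependent for the windowed vision encoder can be computed up front from `grid_thw` plus config scalars. A hedged sketch, assuming this patch is installed: `visual` and `pixel_values` are placeholders for an encoder instance and its flattened patch input.

```python
import torch
from transformers.modeling_vision_utils import (
    get_rotary_pos_ids,
    get_vision_cu_seqlens,
    get_window_index,
)

grid_thw = torch.tensor([[1, 16, 16]])  # toy: a single 16x16-patch image

# `visual` is a placeholder for a windowed vision encoder instance.
rotary_pos_ids = get_rotary_pos_ids(grid_thw, visual.spatial_merge_size)
cu_seqlens = get_vision_cu_seqlens(grid_thw)
window_index, cu_window_seqlens = get_window_index(
    grid_thw, visual.spatial_merge_size, visual.window_size, visual.patch_size, visual.spatial_merge_unit
)

# window_index and cu_window_seqlens travel together: the forward only recomputes
# cu_window_seqlens when window_index is not supplied.
image_embeds = visual(
    pixel_values,
    grid_thw,
    cu_seqlens=cu_seqlens,
    rotary_pos_ids=rotary_pos_ids,
    window_index=window_index,
    cu_window_seqlens=cu_window_seqlens,
).last_hidden_state
```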
""" hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + + if window_index is None: + window_index, cu_window_seqlens = get_window_index( + grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit + ) seq_len, _ = hidden_states.size() + reverse_indices = torch.argsort(window_index) hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) @@ -1257,16 +1259,6 @@ def forward( rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - # Modification here for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: @@ -1282,7 +1274,6 @@ def forward( ) merged_hidden_states = self.merger(hidden_states) - reverse_indices = torch.argsort(window_index) merged_hidden_states = merged_hidden_states[reverse_indices, :] return BaseModelOutputWithPooling( @@ -1777,7 +1768,6 @@ def get_audio_features( audio_outputs = self.audio_tower( input_features, feature_lens=feature_lens, - aftercnn_lens=audio_feat_lengths, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 6dd4e5727fc6..dec2b77cf06e 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -33,6 +33,7 @@ from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import RopeParameters from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, @@ -66,6 +67,86 @@ logger = logging.get_logger(__name__) +def chunk_and_pad_features( + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Split audio features into fixed-size chunks and pad to uniform length. + + Each audio sample is split into chunks of ``n_window * 2`` frames (the last + chunk may be shorter), then all chunks are right-padded to the longest chunk. + + Args: + input_features: ``(feature_dim, total_frames)`` concatenated audio features. 
+ feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window: half the target chunk size in frames. + + Returns: + ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. + ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. + """ + chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() + chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, n_window * 2, chunk_lengths) + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + return padded_feature, chunk_lengths + + +def get_audio_cu_seqlens(chunk_lengths: torch.Tensor) -> torch.Tensor: + """Compute cumulative sequence lengths for audio attention from chunk lengths. + + Applies one stride-2 convolution length reduction, then returns cumulative + boundaries for flash-attention-style sequence packing. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + + Returns: + ``(num_chunks + 1,)`` int32 cumulative sequence boundaries. + """ + after_conv1 = (chunk_lengths - 1) // 2 + 1 + return F.pad(after_conv1.cumsum(0), (1, 0), value=0).to(torch.int32) + + +def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after one stride-2 conv. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + + Returns: + ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_conv)`` grid. + """ + after_conv1 = (chunk_lengths - 1) // 2 + 1 + max_len = after_conv1.max().item() + mask = torch.arange(max_len, device=chunk_lengths.device) < after_conv1.unsqueeze(1) + return mask.flatten().nonzero().squeeze(-1) + + +def get_pool_indices(feature_lens: torch.Tensor) -> torch.Tensor: + """Compute indices for stride-2 pooling over post-CNN audio features. + + Selects every other position (even indices) from each sample's post-CNN + features, accounting for two convolution stages and variable-length samples. + + Args: + feature_lens: ``(batch_size,)`` per-sample raw frame counts. + + Returns: + ``(total_pairs,)`` flat indices for stride-2 pooling across concatenated samples. 
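A tiny numeric trace of the pooling indices (made-up lengths), following the helper's conv-length arithmetic:

```python
import torch

feature_lens = torch.tensor([30, 20])      # made-up raw mel lengths for two samples

after_conv1 = (feature_lens - 1) // 2 + 1  # [15, 10]
after_conv2 = (after_conv1 - 2) // 2 + 1   # [7, 5] post-CNN lengths per sample
num_pairs = after_conv2 // 2               # [3, 2] averaged pairs (an odd trailing frame is dropped)

# get_pool_indices returns the first index of every pair, offset per sample:
# sample 0 contributes 0, 2, 4 and sample 1 starts at offset 7 -> 7, 9.
pool_indices = torch.tensor([0, 2, 4, 7, 9])

# The encoder then averages each pair, matching the old per-sample `avg_pooler` loop:
#   pooled = (hidden_states[pool_indices] + hidden_states[pool_indices + 1]) / 2
print(num_pairs.sum().item())  # 5 pooled frames in total
```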
+ """ + after_conv1 = (feature_lens - 1) // 2 + 1 + after_conv2 = (after_conv1 - 2) // 2 + 1 + num_pairs = (after_conv2 - 1 + 1) // 2 + offsets = F.pad(after_conv2[:-1].cumsum(0), (1, 0), value=0) + pair_offsets = torch.repeat_interleave(offsets, num_pairs) + local_indices = torch.arange(num_pairs.sum(), device=feature_lens.device) + local_indices -= torch.repeat_interleave(F.pad(num_pairs[:-1].cumsum(0), (1, 0), value=0), num_pairs) + return pair_offsets + local_indices * 2 + + @auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B") @strict class Qwen2_5OmniVisionEncoderConfig(Qwen2_5_VLVisionConfig): @@ -1247,60 +1328,76 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value: nn.Module): self.conv1 = value - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if is_flash_attention_requested(self.config): - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - @merge_with_config_defaults @capture_outputs(tie_last_hidden_states=False) @auto_docstring - def forward(self, input_features, feature_lens=None, aftercnn_lens=None, **kwargs: Unpack[TransformersKwargs]): + def forward( + self, + input_features=None, + feature_lens=None, + padded_feature=None, + chunk_lengths=None, + valid_indices=None, + pool_indices=None, + cu_seqlens=None, + **kwargs: Unpack[TransformersKwargs], + ): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length - aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): - mel length after cnn + padded_feature (`torch.FloatTensor`, *optional*): + Precomputed padded audio chunks (from `chunk_and_pad_features`). + chunk_lengths (`torch.LongTensor`, *optional*): + Precomputed per-chunk lengths (from `chunk_and_pad_features`). + valid_indices (`torch.LongTensor`, *optional*): + Precomputed flat indices of valid post-CNN positions (from `get_valid_indices`). + pool_indices (`torch.LongTensor`, *optional*): + Precomputed pair indices for post-encoder average pooling (from `get_pool_indices`). + cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). 
""" - chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + if padded_feature is None: + padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) + + if valid_indices is None: + valid_indices = get_valid_indices(chunk_lengths) - chunk_lengths = torch.full((chunk_num.sum(),), self.n_window * 2, dtype=torch.long, device=feature_lens.device) - tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] - chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) - chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths) + if pool_indices is None: + pool_indices = get_pool_indices(feature_lens) - chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) - padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function( - chunk_list, chunk_lengths, padding_value=0, padding_side="right" + # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) + padded_mask = ( + (torch.arange(padded_feature.shape[2], device=padded_feature.device) < chunk_lengths.unsqueeze(1)) + .unsqueeze(1) + .long() ) padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2) - padded_embed = padded_embed + self.positional_embedding.positional_embedding[ : padded_embed.shape[1], : ].unsqueeze(0).to(padded_embed.dtype) - hidden_states = padded_embed[padded_mask_after_cnn] - cu_seqlens = torch.cat( - ( - torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32), - padded_mask_after_cnn.sum(1).cumsum(0), + hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) + + if cu_seqlens is None: + cu_seqlens = get_audio_cu_seqlens(chunk_lengths) + + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention mask only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + if is_flash_attention_requested(self.config): + attention_mask = None + else: + seq_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) + block_ids = torch.searchsorted(cu_seqlens[1:], seq_idx, right=True) + same_block = block_ids.unsqueeze(0) == block_ids.unsqueeze(1) + attention_mask = torch.full( + (hidden_states.shape[0], hidden_states.shape[0]), + torch.finfo(hidden_states.dtype).min, + dtype=hidden_states.dtype, + device=hidden_states.device, ) - ).to(torch.int32) - attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) + attention_mask = attention_mask.masked_fill(same_block, 0.0).unsqueeze(0).unsqueeze(0) for encoder_layer in self.layers: layer_outputs = encoder_layer( @@ -1311,54 +1408,12 @@ def forward(self, input_features, feature_lens=None, aftercnn_lens=None, **kwarg ) hidden_states = layer_outputs[0] - hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) - token_audio_list = [] - for each_audio_states in hidden_states_list: - each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1) - each_audio_states = self.ln_post(each_audio_states) - each_audio_states = self.proj(each_audio_states) - token_audio_list.append(each_audio_states) - token_audio = torch.cat(token_audio_list, dim=0) + # Post-process: average consecutive pairs per audio, then project + pooled = (hidden_states[pool_indices] + hidden_states[pool_indices + 1]) / 2 + pooled = self.ln_post(pooled) + token_audio = self.proj(pooled) return BaseModelOutputWithPooling(last_hidden_state=token_audio) - def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): - """ - Pads a sequence of tensors to their maximum length on indicated `padding_side`. - Then prepares a mask so that pad tokens are not attended to. - """ - max_len = tensor_len.max() - dim = tensor_list[0].shape[0] - padded_tensor = torch.full( - size=(len(tensor_list), dim, max_len), - fill_value=padding_value, - dtype=self.dtype, - device=tensor_list[0].device, - ) - - batch_mask = torch.zeros( - (len(tensor_len), max_len), - dtype=torch.long, - device=padded_tensor.device, - ) - for i, length in enumerate(tensor_len): - batch_mask[i, :length] = 1 - padded_tensor[i, :, :length] = tensor_list[i] - - feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 - max_len_after_cnn = feature_lens_after_cnn.max() - batch_mask_after_cnn = torch.zeros( - (len(tensor_len), max_len_after_cnn), - dtype=torch.long, - device=padded_tensor.device, - ) - for i, length in enumerate(feature_lens_after_cnn): - batch_mask_after_cnn[i, :length] = 1 - return ( - padded_tensor, - batch_mask.unsqueeze(1), - batch_mask_after_cnn.bool(), - ) - # Ignore copy def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -1508,7 +1563,14 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> @merge_with_config_defaults @capture_outputs def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + window_index: torch.Tensor | None = None, + cu_window_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: @@ -1516,22 +1578,35 @@ def forward( The final hidden states of the model. 
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + cu_window_seqlens (`torch.Tensor`, *optional*): + Precomputed window cumulative sequence lengths (from `get_window_index`). + window_index (`torch.Tensor`, *optional*): + Precomputed window reordering index (from `get_window_index`). + rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + + if window_index is None: + window_index, cu_window_seqlens = get_window_index( + grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit + ) seq_len, _ = hidden_states.size() + reverse_indices = torch.argsort(window_index) hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) @@ -1539,16 +1614,6 @@ def forward( rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - # Modification here for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: @@ -1564,7 +1629,6 @@ def forward( ) merged_hidden_states = self.merger(hidden_states) - reverse_indices = torch.argsort(window_index) merged_hidden_states = merged_hidden_states[reverse_indices, :] return BaseModelOutputWithPooling( @@ -1714,7 +1778,6 @@ def get_audio_features( audio_outputs = self.audio_tower( input_features, feature_lens=feature_lens, - aftercnn_lens=audio_feat_lengths, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index f666d5f760f6..96873304c832 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -30,7 +30,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from ... 
import initialization as init from ...activations import ACT2FN @@ -43,6 +42,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -124,10 +124,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen2_5_VLPatchMerger(nn.Module): @@ -379,81 +377,17 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw.tolist(): - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size - grid_thw_list = grid_thw.tolist() - - for grid_t, grid_h, grid_w in grid_thw_list: - llm_grid_h, llm_grid_w = ( - grid_h // self.spatial_merge_size, - grid_w // self.spatial_merge_size, - ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) - index_padded = index_padded.reshape( - grid_t, - num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size, - ) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, - num_windows_h * num_windows_w, - vit_merger_window_size, - vit_merger_window_size, - ) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != 
-100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += grid_t * llm_grid_h * llm_grid_w - window_index = torch.cat(window_index, dim=0) - - return window_index, cu_window_seqlens - @merge_with_config_defaults @capture_outputs def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + window_index: torch.Tensor | None = None, + cu_window_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: @@ -461,19 +395,34 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + window_index (`torch.Tensor`, *optional*): + Precomputed window reordering index (from `compute_window_index`). + cu_window_seqlens (`torch.Tensor`, *optional*): + Precomputed window cumulative sequence lengths (from `compute_window_index`). + rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + + if window_index is None: + window_index, cu_window_seqlens = get_window_index( + grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit + ) + + reverse_indices = torch.argsort(window_index) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) @@ -485,16 +434,6 @@ def forward( emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens @@ -509,7 +448,6 @@ def forward( ) merged_hidden_states = self.merger(hidden_states) - reverse_indices = torch.argsort(window_index) merged_hidden_states = merged_hidden_states[reverse_indices, :] return 
BaseModelOutputWithPooling( @@ -1117,7 +1055,12 @@ def get_rope_index( # image == 1, video == 2 else: grid_thw = next(grid_iters[modality_type]) - time_interval = tokens_per_second * int(next(second_per_grid_ts)) + # Only apply temporal scaling for videos; still images have no + # temporal dimension to space out (fixes #45325). + if modality_type == 2: + time_interval = tokens_per_second * int(next(second_per_grid_ts)) + else: + time_interval = 1 vision_position_ids = self.get_vision_position_ids( current_pos, grid_thw, 1, spatial_merge_size, time_interval, device=input_ids.device ) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 8a103cefd225..406e75a10049 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -22,7 +22,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from huggingface_hub.dataclasses import strict from ... import initialization as init @@ -34,6 +33,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from ...processing_utils import MultiModalData, ProcessingKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, can_return_tuple, logging @@ -218,81 +218,17 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw.tolist(): - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size - grid_thw_list = grid_thw.tolist() - - for grid_t, grid_h, grid_w in grid_thw_list: - llm_grid_h, llm_grid_w = ( - grid_h // self.spatial_merge_size, - grid_w // self.spatial_merge_size, - ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) - index_padded = index_padded.reshape( - grid_t, - num_windows_h, - vit_merger_window_size, - num_windows_w, - 
vit_merger_window_size, - ) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, - num_windows_h * num_windows_w, - vit_merger_window_size, - vit_merger_window_size, - ) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += grid_t * llm_grid_h * llm_grid_w - window_index = torch.cat(window_index, dim=0) - - return window_index, cu_window_seqlens - @merge_with_config_defaults @capture_outputs def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + window_index: torch.Tensor | None = None, + cu_window_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: @@ -300,19 +236,34 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + window_index (`torch.Tensor`, *optional*): + Precomputed window reordering index (from `compute_window_index`). + cu_window_seqlens (`torch.Tensor`, *optional*): + Precomputed window cumulative sequence lengths (from `compute_window_index`). + rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). Returns: `torch.Tensor`: hidden_states. 
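Hoisting `reverse_indices = torch.argsort(window_index)` to the top of these forwards is safe because `window_index` is a permutation, so its argsort is its inverse wherever it is computed. A minimal standalone illustration:

```python
import torch

# Toy permutation standing in for `window_index` (one entry per spatial merge unit).
window_index = torch.tensor([2, 0, 3, 1])
spatial_merge_unit = 4
x = torch.arange(len(window_index) * spatial_merge_unit).reshape(len(window_index), spatial_merge_unit)

reordered = x[window_index]                    # window-ordered layout used inside the blocks
reverse_indices = torch.argsort(window_index)  # inverse permutation
restored = reordered[reverse_indices]

assert torch.equal(restored, x)  # the original patch order is recovered
```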
""" hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + + if window_index is None: + window_index, cu_window_seqlens = get_window_index( + grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit + ) + + reverse_indices = torch.argsort(window_index) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) @@ -324,16 +275,6 @@ def forward( emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens @@ -348,7 +289,6 @@ def forward( ) merged_hidden_states = self.merger(hidden_states) - reverse_indices = torch.argsort(window_index) merged_hidden_states = merged_hidden_states[reverse_indices, :] return BaseModelOutputWithPooling( @@ -468,7 +408,12 @@ def get_rope_index( # image == 1, video == 2 else: grid_thw = next(grid_iters[modality_type]) - time_interval = tokens_per_second * int(next(second_per_grid_ts)) + # Only apply temporal scaling for videos; still images have no + # temporal dimension to space out (fixes #45325). + if modality_type == 2: + time_interval = tokens_per_second * int(next(second_per_grid_ts)) + else: + time_interval = 1 vision_position_ids = self.get_vision_position_ids( current_pos, grid_thw, 1, spatial_merge_size, time_interval, device=input_ids.device ) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 6dc8755528d7..576ff9b206ff 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -25,7 +25,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from torch.nn import LayerNorm from ... 
import initialization as init @@ -39,6 +38,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, @@ -278,10 +278,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class PatchEmbed(nn.Module): @@ -722,35 +720,6 @@ def get_dtype(self) -> torch.dtype: def get_device(self) -> torch.device: return self.blocks[0].mlp.fc2.weight.device - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - @merge_with_config_defaults @capture_outputs @auto_docstring @@ -758,26 +727,30 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: r""" grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. + cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). 
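The simplified `VisionRotaryEmbedding.forward` is a broadcast multiply over precomputed (row, col) position IDs; a small self-contained sketch (toy dimensions, illustrative only) showing it reproduces the old table-lookup path:

```python
import torch

dim, theta = 8, 10000.0  # illustrative rotary sub-dimension
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

# (row, col) position ids for a toy 2x3 patch grid, as get_rotary_pos_ids would emit.
pos_ids = torch.tensor([[r, c] for r in range(2) for c in range(3)])

# New path: broadcast multiply, (tokens, 2) -> (tokens, 2 * dim // 2).
freqs_new = (pos_ids.unsqueeze(-1) * inv_freq).flatten(1)

# Old path: build a (max_grid, dim // 2) frequency table, then index it.
freq_table = torch.outer(torch.arange(int(pos_ids.max()) + 1, dtype=torch.float), inv_freq)
freqs_old = freq_table[pos_ids].flatten(1)

assert torch.allclose(freqs_new, freqs_old)
```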
""" hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) for blk in self.blocks: hidden_states = blk( diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index eba3eec02fdd..9ab2730c9e88 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -43,6 +43,8 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -76,10 +78,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3_5TextRotaryEmbedding(nn.Module): @@ -293,7 +293,7 @@ def torch_chunk_gated_delta_rule( # for each chunk for i in range(0, total_sequence_length // chunk_size): q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + attn = q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i] v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state v_new = v_i - v_prime attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state @@ -1027,128 +1027,52 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // 
merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings - - def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds - @merge_with_config_defaults @capture_outputs - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + grid_thw: 
torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + embed_indices: torch.Tensor | None = None, + bilinear_weights: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: """ Args: hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). + bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + if embed_indices is None or bilinear_weights is None: + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -1156,16 +1080,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for blk in self.blocks: hidden_states = blk( hidden_states, diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 8fddbc6115c1..ab1304871979 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -26,6 +26,8 @@ from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults @@ -420,23 +422,50 @@ def 
__init__(self, config, *inputs, **kwargs) -> None: @merge_with_config_defaults @capture_outputs - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + embed_indices: torch.Tensor | None = None, + bilinear_weights: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: """ Args: hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). + bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + if embed_indices is None or bilinear_weights is None: + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -444,16 +473,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for blk in self.blocks: hidden_states = blk( hidden_states, diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index be4501d34903..c0e72f6d1076 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -44,6 +44,8 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as 
get_rotary_pos_ids from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -77,10 +79,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3_5MoeTextRotaryEmbedding(nn.Module): @@ -294,7 +294,7 @@ def torch_chunk_gated_delta_rule( # for each chunk for i in range(0, total_sequence_length // chunk_size): q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + attn = q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i] v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state v_new = v_i - v_prime attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state @@ -1120,128 +1120,52 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings - - def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - 
w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds - @merge_with_config_defaults @capture_outputs - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + embed_indices: torch.Tensor | None = None, + bilinear_weights: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: """ Args: hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). + bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states.
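A small sketch of what the precomputed bilinear tensors look like and how the forward reduces them, assuming the standalone `get_pos_embed_indices` helper added by this patch; the table size, merge size and hidden size below are illustrative, not checkpoint values:

```python
import torch
import torch.nn as nn
from transformers.modeling_vision_utils import get_pos_embed_indices

num_grid_per_side, spatial_merge_size, hidden_size = 16, 2, 32  # illustrative
pos_embed = nn.Embedding(num_grid_per_side**2, hidden_size)

grid_thw = torch.tensor([[1, 8, 8]])  # one 8x8-patch image
embed_indices, bilinear_weights = get_pos_embed_indices(
    grid_thw, num_grid_per_side, spatial_merge_size
)  # both (4, total_thw): the four bilinear corners per patch position

# Same reduction as in the encoder forward: weighted sum over the four corners.
pos_embeds = (pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0)
print(pos_embeds.shape)  # torch.Size([64, 32])
```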
""" hidden_states = self.patch_embed(hidden_states) - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + if embed_indices is None or bilinear_weights is None: + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -1249,16 +1173,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for blk in self.blocks: hidden_states = blk( hidden_states, @@ -2001,6 +1915,7 @@ class Qwen3_5MoeForConditionalGeneration(Qwen3_5MoePreTrainedModel, GenerationMi # Reference: fix gemma3 grad acc #37208 accepts_loss_kwargs = False config: Qwen3_5MoeConfig + _tp_plan = {"lm_head": "colwise_gather_output"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py index 312b22bc88ed..f3b4b80aa3a6 100644 --- a/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py @@ -248,6 +248,8 @@ def __init__(self, config): class Qwen3_5MoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration): + _tp_plan = {"lm_head": "colwise_gather_output"} + def forward(self, **super_kwargs): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 5141ffc388c8..4518857b620b 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -52,6 +52,7 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import ( @@ -616,6 +617,79 @@ def forward( return outputs +def chunk_and_pad_features( + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Split audio features into fixed-size chunks and pad to uniform length. 
+ + Each audio sample is split into chunks of ``n_window * 2`` frames (the last + chunk may be shorter), then all chunks are right-padded to the longest chunk. + + Args: + input_features: ``(feature_dim, total_frames)`` concatenated audio features. + feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window: half the target chunk size in frames. + + Returns: + ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. + ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. + """ + chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() + chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, n_window * 2, chunk_lengths) + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + return padded_feature, chunk_lengths + + +def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after CNN extraction. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + + Returns: + ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_cnn)`` grid. + """ + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + max_len_after_cnn = feature_lens_after_cnn.max().item() + mask = torch.arange(max_len_after_cnn, device=chunk_lengths.device) < feature_lens_after_cnn.unsqueeze(1) + return mask.flatten().nonzero().squeeze(-1) + + +def get_audio_cu_seqlens( + chunk_lengths: torch.Tensor, feature_lens: torch.Tensor, n_window_infer: int, n_window: int +) -> torch.Tensor: + """Compute cumulative sequence lengths for audio attention windowing. + + Splits each sample's post-CNN features into inference windows and returns + cumulative boundaries for flash-attention-style sequence packing. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window_infer: inference window size (in raw frames). + n_window: half the chunk size (in raw frames). + + Returns: + ``(num_windows + 1,)`` int32 cumulative sequence boundaries. + """ + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + max_len_after_cnn = feature_lens_after_cnn.max().item() + cu_chunk_lens = [0] + n_window_ratio = n_window_infer // (n_window * 2) + window_aftercnn = max_len_after_cnn * n_window_ratio + for cnn_len in aftercnn_lens: + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + return torch.tensor(cu_chunk_lens, device=feature_lens.device).cumsum(-1, dtype=torch.int32) + + @auto_docstring( custom_intro=""" Transformer encoder consisting of *config.encoder_layers* self attention layers. 
Each layer is a @@ -673,65 +747,45 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value): self.conv2d1 = value - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if is_flash_attention_requested(self.config): - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - @merge_with_config_defaults @capture_outputs(tie_last_hidden_states=False) @auto_docstring def forward( self, - input_features, + input_features=None, feature_lens=None, - aftercnn_lens=None, - **kwargs, + padded_feature=None, + chunk_lengths=None, + valid_indices=None, + cu_seqlens=None, + **kwargs: Unpack[TransformersKwargs], ): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length - aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): - mel length after cnn + padded_feature (`torch.FloatTensor`, *optional*): + Precomputed padded audio chunks (from `chunk_and_pad_features`). + chunk_lengths (`torch.LongTensor`, *optional*): + Precomputed per-chunk lengths (from `chunk_and_pad_features`). + valid_indices (`torch.LongTensor`, *optional*): + Precomputed flat indices of valid post-CNN positions (from `get_valid_indices`). + cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_audio_cu_seqlens`). 
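A minimal sketch of precomputing the audio tensors with the module-level helpers added in this file; the mel dimension, frame counts and window sizes are illustrative, and the commented `audio_tower(...)` call only indicates how the precomputed tensors map onto the new keyword arguments:

```python
import torch
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
    chunk_and_pad_features,
    get_audio_cu_seqlens,
    get_valid_indices,
)

feature_lens = torch.tensor([250, 180])                       # two samples, in mel frames
input_features = torch.randn(128, int(feature_lens.sum()))    # concatenated along time
n_window, n_window_infer = 50, 400                            # illustrative config values

padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, n_window)
valid_indices = get_valid_indices(chunk_lengths)
cu_seqlens = get_audio_cu_seqlens(chunk_lengths, feature_lens, n_window_infer, n_window)

# audio_tower(padded_feature=padded_feature, chunk_lengths=chunk_lengths,
#             valid_indices=valid_indices, cu_seqlens=cu_seqlens)
```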
""" - aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) - chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + if padded_feature is None: + padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) - chunk_lengths = torch.full((chunk_num.sum(),), self.n_window * 2, dtype=torch.long, device=feature_lens.device) - tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] - chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) - chunk_lengths[chunk_lengths == 0] = self.n_window * 2 + if valid_indices is None: + valid_indices = get_valid_indices(chunk_lengths) - chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) - padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) - feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) - padded_mask_after_cnn = nn.utils.rnn.pad_sequence( - [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], - batch_first=True, - ) + if cu_seqlens is None: + cu_seqlens = get_audio_cu_seqlens(chunk_lengths, feature_lens, self.n_window_infer, self.n_window) + + # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) padded_feature = padded_feature.unsqueeze(1) - # Split to chunk to avoid OOM during convolution - padded_embeds = [] - for chunk in padded_feature.split(self.conv_chunksize, dim=0): - padded_embed = F.gelu(self.conv2d1(chunk)) - padded_embed = F.gelu(self.conv2d2(padded_embed)) - padded_embed = F.gelu(self.conv2d3(padded_embed)) - padded_embeds.append(padded_embed) - padded_embed = torch.cat(padded_embeds, dim=0) + padded_embed = F.gelu(self.conv2d1(padded_feature)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = F.gelu(self.conv2d3(padded_embed)) b, c, f, t = padded_embed.size() padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) @@ -741,22 +795,36 @@ def forward( .to(padded_embed.dtype) ) padded_embed = padded_embed + positional_embedding - hidden_states = padded_embed[padded_mask_after_cnn] - cu_chunk_lens = [0] - window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2)) - for cnn_len in aftercnn_lens: - cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) - remainder = cnn_len % window_aftercnn - if remainder != 0: - cu_chunk_lens += [remainder] - cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) + hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) + + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention mask only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + if is_flash_attention_requested(self.config): + attention_mask = None + else: + seq_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) + block_ids = torch.searchsorted(cu_seqlens[1:], seq_idx, right=True) + same_block = block_ids.unsqueeze(0) == block_ids.unsqueeze(1) + attention_mask = ( + torch.full( + (hidden_states.shape[0], hidden_states.shape[0]), + torch.finfo(hidden_states.dtype).min, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + .masked_fill(same_block, 0.0) + .unsqueeze(0) + .unsqueeze(0) + ) for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, cu_seqlens, + attention_mask=attention_mask, ) - hidden_states = layer_outputs[0] hidden_states = self.ln_post(hidden_states) @@ -765,44 +833,6 @@ def forward( hidden_states = self.proj2(hidden_states) return BaseModelOutputWithPooling(last_hidden_state=hidden_states) - def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): - """ - Pads a sequence of tensors to their maximum length on indicated `padding_side`. - Then prepares a mask so that pad tokens are not attended to. - """ - max_len = tensor_len.max() - dim = tensor_list[0].shape[0] - padded_tensor = torch.full( - size=(len(tensor_list), dim, max_len), - fill_value=padding_value, - dtype=self.dtype, - device=tensor_list[0].device, - ) - - batch_mask = torch.zeros( - (len(tensor_len), max_len), - dtype=torch.long, - device=padded_tensor.device, - ) - for i, length in enumerate(tensor_len): - batch_mask[i, :length] = 1 - padded_tensor[i, :, :length] = tensor_list[i] - - feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 - max_len_after_cnn = feature_lens_after_cnn.max() - batch_mask_after_cnn = torch.zeros( - (len(tensor_len), max_len_after_cnn), - dtype=torch.long, - device=padded_tensor.device, - ) - for i, length in enumerate(feature_lens_after_cnn): - batch_mask_after_cnn[i, :length] = 1 - return ( - padded_tensor, - batch_mask.unsqueeze(1), - batch_mask_after_cnn.bool(), - ) - # Ignore copy def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -950,10 +980,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3OmniMoeTextTopKRouter(nn.Module): @@ -1089,113 +1117,17 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # 
block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings - - def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds - @merge_with_config_defaults @capture_outputs def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + embed_indices: torch.Tensor | None = 
None, + bilinear_weights: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ Args: @@ -1203,16 +1135,34 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). + bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + if embed_indices is None or bilinear_weights is None: + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -1220,16 +1170,6 @@ def forward( emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): hidden_states = blk( diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index d336784b3b49..55a0621bc68d 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -46,7 +46,7 @@ from ...processing_utils import ProcessorMixin, Unpack from ...tokenization_utils_base import TextInput from ...utils import auto_docstring, can_return_tuple, logging -from ...utils.generic import TransformersKwargs, merge_with_config_defaults +from ...utils.generic import TransformersKwargs, is_flash_attention_requested, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs from ...video_utils import VideoInput, make_batched_videos from ..mimi.modeling_mimi import MimiLayerScale @@ -103,6 +103,90 @@ logger = logging.get_logger(__name__) +def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor: + """Compute 
output lengths after the 3-layer CNN feature extractor with deepstack. + + Three stride-2 convolutions within each 100-frame block, plus 13 output frames + per full block from the deepstack path. + """ + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + + +def chunk_and_pad_features( + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Split audio features into fixed-size chunks and pad to uniform length. + + Each audio sample is split into chunks of ``n_window * 2`` frames (the last + chunk may be shorter), then all chunks are right-padded to the longest chunk. + + Args: + input_features: ``(feature_dim, total_frames)`` concatenated audio features. + feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window: half the target chunk size in frames. + + Returns: + ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. + ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. + """ + chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() + chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, n_window * 2, chunk_lengths) + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + return padded_feature, chunk_lengths + + +def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after CNN extraction. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + + Returns: + ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_cnn)`` grid. + """ + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + max_len_after_cnn = feature_lens_after_cnn.max().item() + mask = torch.arange(max_len_after_cnn, device=chunk_lengths.device) < feature_lens_after_cnn.unsqueeze(1) + return mask.flatten().nonzero().squeeze(-1) + + +def get_audio_cu_seqlens( + chunk_lengths: torch.Tensor, feature_lens: torch.Tensor, n_window_infer: int, n_window: int +) -> torch.Tensor: + """Compute cumulative sequence lengths for audio attention windowing. + + Splits each sample's post-CNN features into inference windows and returns + cumulative boundaries for flash-attention-style sequence packing. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window_infer: inference window size (in raw frames). + n_window: half the chunk size (in raw frames). + + Returns: + ``(num_windows + 1,)`` int32 cumulative sequence boundaries. 
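To make the windowing concrete, a standalone sketch of the boundary arithmetic this function performs, with toy numbers (no model required): a sample whose post-CNN length is 120 frames, with 52 post-CNN frames per inference window, yields two full windows plus a 16-frame tail.

```python
import torch

aftercnn_len, window_aftercnn = 120, 52  # illustrative post-CNN length and window size

bounds = [0]
bounds += [window_aftercnn] * (aftercnn_len // window_aftercnn)  # two full windows
if aftercnn_len % window_aftercnn:
    bounds += [aftercnn_len % window_aftercnn]                   # 16-frame tail window
cu_seqlens = torch.tensor(bounds).cumsum(-1, dtype=torch.int32)
print(cu_seqlens)  # [0, 52, 104, 120]
```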
+ """ + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + max_len_after_cnn = feature_lens_after_cnn.max().item() + cu_chunk_lens = [0] + n_window_ratio = n_window_infer // (n_window * 2) + window_aftercnn = max_len_after_cnn * n_window_ratio + for cnn_len in aftercnn_lens: + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + return torch.tensor(cu_chunk_lens, device=feature_lens.device).cumsum(-1, dtype=torch.int32) + + @dataclass @auto_docstring class BaseModelOutputWithDeepstackFeatures(BaseModelOutputWithPooling): @@ -892,37 +976,45 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.conv2d1 = value + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) + @auto_docstring def forward( self, - input_features, + input_features=None, feature_lens=None, - aftercnn_lens=None, - **kwargs, + padded_feature=None, + chunk_lengths=None, + valid_indices=None, + cu_seqlens=None, + **kwargs: Unpack[TransformersKwargs], ): - aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) - chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() - - chunk_lengths = torch.full((chunk_num.sum(),), self.n_window * 2, dtype=torch.long, device=feature_lens.device) - tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] - chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) - chunk_lengths[chunk_lengths == 0] = self.n_window * 2 - - chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) - padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) - feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) - padded_mask_after_cnn = nn.utils.rnn.pad_sequence( - [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], - batch_first=True, - ) + r""" + feature_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length + padded_feature (`torch.FloatTensor`, *optional*): + Precomputed padded audio chunks (from `chunk_and_pad_features`). + chunk_lengths (`torch.LongTensor`, *optional*): + Precomputed per-chunk lengths (from `chunk_and_pad_features`). + valid_indices (`torch.LongTensor`, *optional*): + Precomputed flat indices of valid post-CNN positions (from `get_valid_indices`). + cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_audio_cu_seqlens`). 
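A self-contained sketch of how the non-flash-attention path turns `cu_seqlens` into a block-diagonal bidirectional mask via `searchsorted`, shown on a toy packed sequence (values illustrative):

```python
import torch

# Six packed tokens in blocks of lengths 3, 2 and 1.
cu_seqlens = torch.tensor([0, 3, 5, 6])
seq_idx = torch.arange(6)

# Each token's block id; tokens attend only within their own block.
block_ids = torch.searchsorted(cu_seqlens[1:], seq_idx, right=True)
same_block = block_ids.unsqueeze(0) == block_ids.unsqueeze(1)
print(block_ids)         # tensor([0, 0, 0, 1, 1, 2])
print(same_block.int())  # 1s inside each block, 0s across blocks
```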
+ """ + if padded_feature is None: + padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) + + if valid_indices is None: + valid_indices = get_valid_indices(chunk_lengths) + + if cu_seqlens is None: + cu_seqlens = get_audio_cu_seqlens(chunk_lengths, feature_lens, self.n_window_infer, self.n_window) + + # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) padded_feature = padded_feature.unsqueeze(1) - # Split to chunk to avoid OOM during convolution - padded_embeds = [] - for chunk in padded_feature.split(self.conv_chunksize, dim=0): - padded_embed = F.gelu(self.conv2d1(chunk)) - padded_embed = F.gelu(self.conv2d2(padded_embed)) - padded_embed = F.gelu(self.conv2d3(padded_embed)) - padded_embeds.append(padded_embed) - padded_embed = torch.cat(padded_embeds, dim=0) + padded_embed = F.gelu(self.conv2d1(padded_feature)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = F.gelu(self.conv2d3(padded_embed)) b, c, f, t = padded_embed.size() padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) @@ -932,22 +1024,36 @@ def forward( .to(padded_embed.dtype) ) padded_embed = padded_embed + positional_embedding - hidden_states = padded_embed[padded_mask_after_cnn] - cu_chunk_lens = [0] - window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2)) - for cnn_len in aftercnn_lens: - cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) - remainder = cnn_len % window_aftercnn - if remainder != 0: - cu_chunk_lens += [remainder] - cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) + hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) + + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention mask only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. Though it will not be a 100% match for FA2's `varlen` path + if is_flash_attention_requested(self.config): + attention_mask = None + else: + seq_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) + block_ids = torch.searchsorted(cu_seqlens[1:], seq_idx, right=True) + same_block = block_ids.unsqueeze(0) == block_ids.unsqueeze(1) + attention_mask = ( + torch.full( + (hidden_states.shape[0], hidden_states.shape[0]), + torch.finfo(hidden_states.dtype).min, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + .masked_fill(same_block, 0.0) + .unsqueeze(0) + .unsqueeze(0) + ) for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, cu_seqlens, + attention_mask=attention_mask, ) - hidden_states = layer_outputs[0] hidden_states = self.ln_post(hidden_states) diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 73678ee8c736..ee654d205f66 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -25,7 +25,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from ... 
import initialization as init from ...activations import ACT2FN @@ -38,6 +37,8 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -99,10 +100,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3VLVisionPatchMerger(nn.Module): @@ -659,113 +658,17 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings - - def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = 
w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds - @merge_with_config_defaults @capture_outputs def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + embed_indices: torch.Tensor | None = None, + bilinear_weights: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ Args: @@ -773,16 +676,34 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices into the position embedding table (from `compute_pos_embed_indices`). + bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Interpolation weights for the four bilinear corners (from `compute_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. 
""" hidden_states = self.patch_embed(hidden_states) - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + if embed_indices is None or bilinear_weights is None: + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -790,16 +711,6 @@ def forward( emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): hidden_states = blk( diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 74d887726119..4433a0b75e15 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -20,7 +20,6 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from huggingface_hub.dataclasses import strict from ... 
import initialization as init @@ -34,6 +33,8 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_rope_utils import RopeParameters, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ...processing_utils import ProcessingKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, can_return_tuple, logging @@ -445,113 +446,17 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings - - def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 
- dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds - @merge_with_config_defaults @capture_outputs def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + embed_indices: torch.Tensor | None = None, + bilinear_weights: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ Args: @@ -559,16 +464,34 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices into the position embedding table (from `compute_pos_embed_indices`). + bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Interpolation weights for the four bilinear corners (from `compute_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. 
""" hidden_states = self.patch_embed(hidden_states) - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + if embed_indices is None or bilinear_weights is None: + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -576,16 +499,6 @@ def forward( emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): hidden_states = blk( diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 6d4c68c1a752..b9f5fa635e83 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -43,6 +43,8 @@ from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -399,10 +401,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) def apply_rotary_pos_emb_vision( @@ -643,113 +643,17 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = 
torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings - - def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds - @merge_with_config_defaults 
@capture_outputs def forward( - self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, + embed_indices: torch.Tensor | None = None, + bilinear_weights: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ Args: @@ -757,16 +661,34 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices into the position embedding table (from `compute_pos_embed_indices`). + bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Interpolation weights for the four bilinear corners (from `compute_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + if embed_indices is None or bilinear_weights is None: + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds - rotary_pos_emb = self.rot_pos_emb(grid_thw) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -774,16 +696,6 @@ def forward( emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): hidden_states = blk( diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 26d89b313167..338ddfdeeb9c 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -32,6 +32,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, 
auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults @@ -50,39 +51,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, grid_thw, merge_sizes) -> tuple[torch.Tensor, torch.Tensor]: - pos_ids = [] - for (t, h, w), merge_size in zip(grid_thw, merge_sizes): - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // merge_size, - merge_size, - w // merge_size, - merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // merge_size, - merge_size, - w // merge_size, - merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_thw = grid_thw[:, 1:].max() - - seq = torch.arange(max_grid_thw, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - rotary_pos_emb_full = torch.outer(seq, self.inv_freq) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - - return (emb.cos(), emb.sin()) + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class VideoLlama3VisionEmbeddings(nn.Module): @@ -444,6 +414,8 @@ def forward( pixel_values: torch.Tensor, grid_thw: torch.Tensor, merge_sizes: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutput: r""" @@ -451,18 +423,20 @@ def forward( The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. merge_sizes (`torch.Tensor` of shape `(num_images_or_videos,)`): The spatial downsampling ratio of each image or video feature. + cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). 
""" - position_embeddings = self.rotary_pos_emb(grid_thw, merge_sizes) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, merge_sizes) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) hidden_states = self.embeddings(pixel_values.type(self.dtype)) encoder_outputs: BaseModelOutput = self.encoder( diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 4eef74580c87..1aad0f3daefa 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -37,6 +37,7 @@ ) from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import ImagesKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import ( @@ -122,39 +123,7 @@ def __post_init__(self, **kwargs): class VideoLlama3VisionRotaryEmbedding(VisionRotaryEmbedding): - def forward(self, grid_thw, merge_sizes) -> tuple[torch.Tensor, torch.Tensor]: - pos_ids = [] - for (t, h, w), merge_size in zip(grid_thw, merge_sizes): - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // merge_size, - merge_size, - w // merge_size, - merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // merge_size, - merge_size, - w // merge_size, - merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_thw = grid_thw[:, 1:].max() - - seq = torch.arange(max_grid_thw, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - rotary_pos_emb_full = torch.outer(seq, self.inv_freq) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - - return (emb.cos(), emb.sin()) + pass class VideoLlama3VisionEmbeddings(nn.Module): @@ -407,6 +376,8 @@ def forward( pixel_values: torch.Tensor, grid_thw: torch.Tensor, merge_sizes: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutput: r""" @@ -414,18 +385,20 @@ def forward( The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. merge_sizes (`torch.Tensor` of shape `(num_images_or_videos,)`): The spatial downsampling ratio of each image or video feature. 
+ cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). """ - position_embeddings = self.rotary_pos_emb(grid_thw, merge_sizes) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + if rotary_pos_ids is None: + rotary_pos_ids = get_rotary_pos_ids(grid_thw, merge_sizes) + + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) hidden_states = self.embeddings(pixel_values.type(self.dtype)) encoder_outputs: BaseModelOutput = self.encoder( From fe46ba2ecde8eb6acb4b9e42f14223e74f2c8888 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 10:51:19 +0200 Subject: [PATCH 02/56] Fix stale compute_ docstring references to match actual function names Co-Authored-By: Claude Opus 4.6 (1M context) --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 6 +++--- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 6 +++--- src/transformers/models/qwen3_5/modeling_qwen3_5.py | 8 ++++---- src/transformers/models/qwen3_5/modular_qwen3_5.py | 6 +++--- .../models/qwen3_5_moe/modeling_qwen3_5_moe.py | 8 ++++---- .../models/qwen3_omni_moe/modeling_qwen3_omni_moe.py | 6 +++--- src/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 6 +++--- src/transformers/models/qwen3_vl/modular_qwen3_vl.py | 6 +++--- .../models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 6 +++--- 9 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 96873304c832..ecdde65ae4b8 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -396,11 +396,11 @@ def forward( grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). window_index (`torch.Tensor`, *optional*): - Precomputed window reordering index (from `compute_window_index`). + Precomputed window reordering index (from `get_window_index`). cu_window_seqlens (`torch.Tensor`, *optional*): - Precomputed window cumulative sequence lengths (from `compute_window_index`). + Precomputed window cumulative sequence lengths (from `get_window_index`). rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). 
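[Editor's note on the corrected references above: a minimal sketch of the intended call pattern for the standalone helpers, assuming a transformers build that already contains the new `modeling_vision_utils` module from the first patch. The toy `grid_thw` values and the `model.visual(...)` attribute path are illustrative assumptions, not taken from this patch series.]

```python
import torch

from transformers.modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens

# one image: 1 temporal frame, 32 x 32 patches -> (t, h, w)
grid_thw = torch.tensor([[1, 32, 32]])

# (total_patches + 1,) int32 boundaries; here tensor([0, 1024], dtype=torch.int32)
cu_seqlens = get_vision_cu_seqlens(grid_thw)

# (total_tokens, 2) long tensor of (row, col) ids, second argument is the spatial merge size
rotary_pos_ids = get_rotary_pos_ids(grid_thw, 2)

# the vision forward can then consume the ready-made tensors instead of recomputing them,
# e.g. (attribute name assumed):
# model.visual(pixel_values, grid_thw, cu_seqlens=cu_seqlens, rotary_pos_ids=rotary_pos_ids)
```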
diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 406e75a10049..3d489cd6f66e 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -237,11 +237,11 @@ def forward( grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). window_index (`torch.Tensor`, *optional*): - Precomputed window reordering index (from `compute_window_index`). + Precomputed window reordering index (from `get_window_index`). cu_window_seqlens (`torch.Tensor`, *optional*): - Precomputed window cumulative sequence lengths (from `compute_window_index`). + Precomputed window cumulative sequence lengths (from `get_window_index`). rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 9ab2730c9e88..555497206644 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -293,7 +293,7 @@ def torch_chunk_gated_delta_rule( # for each chunk for i in range(0, total_sequence_length // chunk_size): q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn = q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state v_new = v_i - v_prime attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state @@ -1046,13 +1046,13 @@ def forward( grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `compute_pos_embed_indices`). + Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `compute_pos_embed_indices`). + Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index ab1304871979..1e075af3404a 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -439,13 +439,13 @@ def forward( grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). 
+ Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `compute_pos_embed_indices`). + Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `compute_pos_embed_indices`). + Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index c0e72f6d1076..8d4327e0cc79 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -294,7 +294,7 @@ def torch_chunk_gated_delta_rule( # for each chunk for i in range(0, total_sequence_length // chunk_size): q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn = q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state v_new = v_i - v_prime attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state @@ -1139,13 +1139,13 @@ def forward( grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `compute_pos_embed_indices`). + Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `compute_pos_embed_indices`). + Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 4518857b620b..5b09794e08ee 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1136,13 +1136,13 @@ def forward( grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). 
embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `compute_pos_embed_indices`). + Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `compute_pos_embed_indices`). + Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index ee654d205f66..7808d7e520b1 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -677,13 +677,13 @@ def forward( grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `compute_pos_embed_indices`). + Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `compute_pos_embed_indices`). + Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 4433a0b75e15..df69111d99d7 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -465,13 +465,13 @@ def forward( grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `compute_pos_embed_indices`). + Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `compute_pos_embed_indices`). + Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. 
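[Editor's note: the Qwen3-VL-family docstrings above reference the same helpers plus the bilinear position-embedding precompute. A hedged sketch with made-up values — the grid size and `num_grid_per_side` below are placeholders, not values from this patch:]

```python
import torch

from transformers.modeling_vision_utils import (
    get_pos_embed_indices,
    get_rotary_pos_ids_interleaved,
    get_vision_cu_seqlens,
)

grid_thw = torch.tensor([[1, 16, 16]])  # (t, h, w) of a single image
spatial_merge_size = 2
num_grid_per_side = 24  # placeholder; the real value comes from the vision config

# (4, total_thw) corner indices and matching interpolation weights
embed_indices, bilinear_weights = get_pos_embed_indices(grid_thw, num_grid_per_side, spatial_merge_size)

# interleaved (row, col) position ids and cumulative lengths for the attention blocks
rotary_pos_ids = get_rotary_pos_ids_interleaved(grid_thw, spatial_merge_size)
cu_seqlens = get_vision_cu_seqlens(grid_thw)

# inside the encoder the position embedding then reduces to a weighted lookup, as in the
# forwards shown above:
# pos_embeds = (pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0)
```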
diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index b9f5fa635e83..ae80e02edd8a 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -662,13 +662,13 @@ def forward( grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `compute_cu_seqlens`). + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `compute_pos_embed_indices`). + Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `compute_pos_embed_indices`). + Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). Returns: `torch.Tensor`: hidden_states. From 84439a045ea1851953e986a141f2492d99d3be09 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 10:57:34 +0200 Subject: [PATCH 03/56] =?UTF-8?q?Revert=20mlcd=20changes=20=E2=80=94=20not?= =?UTF-8?q?=20part=20of=20this=20PR?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 (1M context) --- src/transformers/models/mlcd/modeling_mlcd.py | 138 ++++++++++++------ src/transformers/models/mlcd/modular_mlcd.py | 110 +++++++++----- 2 files changed, 167 insertions(+), 81 deletions(-) diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index a9e6ff9a930a..efc0bb807d2d 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -59,8 +59,37 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor: + """ + Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size. + + Args: + num_patches_height (int): Number of patches in the height dimension. + num_patches_width (int): Number of patches in the width dimension. + + Returns: + torch.Tensor: Rotary positional embeddings for the given grid size. 
+ """ + # Generate position IDs for height and width dimensions + hpos_ids = ( + torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width) + ) + wpos_ids = ( + torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1) + ) + + # Flatten and stack the position IDs + pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1) + + # Generate the full rotary positional embeddings for the maximum grid size + max_grid_size = max(num_patches_height, num_patches_width) + seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + + # Select and flatten the embeddings based on the position IDs + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + + return rotary_pos_emb class MLCDVisionEmbeddings(nn.Module): @@ -370,8 +399,7 @@ def forward( @auto_docstring class MLCDPreTrainedModel(PreTrainedModel): config: MLCDVisionConfig - base_model_prefix = "vision_model" - _no_split_modules = ["MLCDEncoderLayer"] + base_model_prefix = "mlcd" supports_gradient_checkpointing = True accepts_loss_kwargs = False _supports_flash_attn = True @@ -406,7 +434,7 @@ def _init_weights(self, module): fc_std = (2 * module.config.hidden_size) ** -0.5 * factor init.normal_(module.fc1.weight, std=fc_std) init.normal_(module.fc2.weight, std=in_proj_std) - elif isinstance(module, MLCDVisionModel): + elif isinstance(module, MLCDVisionTransformer): factor = self.config.initializer_factor pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std) @@ -420,19 +448,15 @@ def _init_weights(self, module): init.copy_(module.inv_freq, inv_freq) -@auto_docstring( - custom_intro=""" - The vision model from M_L_C_D without any head or projection on top. - """ -) -class MLCDVisionModel(MLCDPreTrainedModel): +class MLCDVisionTransformer(MLCDPreTrainedModel): config: MLCDVisionConfig main_input_name = "pixel_values" input_modalities = ("image",) - _input_embed_layer = "patch_embedding" + _no_split_modules = ["MLCDEncoderLayer"] def __init__(self, config: MLCDVisionConfig): super().__init__(config) + self.config = config embed_dim = config.hidden_size self.embeddings = MLCDVisionEmbeddings(config) @@ -451,43 +475,13 @@ def forward( pixel_values: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - r""" - Example: - - ```python - >>> import httpx - >>> from io import BytesIO - >>> from PIL import Image - >>> from transformers import AutoProcessor, MLCDVisionModel - >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448") - >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> with httpx.stream("GET", url) as response: - ... image = Image.open(BytesIO(response.read())) - >>> inputs = processor(images=image, return_tensors="pt") - - >>> with torch.no_grad(): - ... 
outputs = model(**inputs, output_attentions=True) - - >>> features = outputs.last_hidden_state - >>> print(f"Extracted features shape: {features.shape}") - >>> print(f"Number of attention layers: {len(outputs.attentions)}") - >>> print(f"Attention shape: {outputs.attentions[0].shape}") - ```""" if pixel_values is None: raise ValueError("You have to specify pixel_values") num_patches_height = pixel_values.shape[-2] // self.config.patch_size num_patches_width = pixel_values.shape[-1] // self.config.patch_size - hpos_ids = ( - torch.arange(num_patches_height, device=pixel_values.device).unsqueeze(1).expand(-1, num_patches_width) - ) - wpos_ids = ( - torch.arange(num_patches_width, device=pixel_values.device).unsqueeze(0).expand(num_patches_height, -1) - ) - pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1) - rotary_pos_emb = self.vision_rotary_embedding(pos_ids) + rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width) + rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device) rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -511,4 +505,60 @@ def forward( ) +@auto_docstring( + custom_intro=""" + The vision model from M_L_C_D without any head or projection on top. + """ +) +class MLCDVisionModel(MLCDPreTrainedModel): + config: MLCDVisionConfig + main_input_name = "pixel_values" + input_modalities = ("image",) + _no_split_modules = ["MLCDEncoderLayer"] + + def __init__(self, config: MLCDVisionConfig): + super().__init__(config) + self.vision_model = MLCDVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + Example: + + ```python + >>> import httpx + >>> from io import BytesIO + >>> from PIL import Image + >>> from transformers import AutoProcessor, MLCDVisionModel + >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448") + >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())) + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... 
outputs = model(**inputs, output_attentions=True) + + >>> features = outputs.last_hidden_state + >>> print(f"Extracted features shape: {features.shape}") + >>> print(f"Number of attention layers: {len(outputs.attentions)}") + >>> print(f"Attention shape: {outputs.attentions[0].shape}") + ```""" + return self.vision_model( + pixel_values=pixel_values, + **kwargs, + ) + + __all__ = ["MLCDPreTrainedModel", "MLCDVisionModel"] diff --git a/src/transformers/models/mlcd/modular_mlcd.py b/src/transformers/models/mlcd/modular_mlcd.py index a1c291211730..01641c7d5ce1 100644 --- a/src/transformers/models/mlcd/modular_mlcd.py +++ b/src/transformers/models/mlcd/modular_mlcd.py @@ -30,6 +30,7 @@ CLIPEncoderLayer, CLIPVisionEmbeddings, CLIPVisionModel, + CLIPVisionTransformer, ) from ..llama.modeling_llama import eager_attention_forward from ..qwen2_vl.modeling_qwen2_vl import VisionRotaryEmbedding, apply_rotary_pos_emb_vision @@ -83,7 +84,37 @@ class MLCDMLP(CLIPMLP): class MLCDRotaryEmbedding(VisionRotaryEmbedding): - pass + def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor: + """ + Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size. + + Args: + num_patches_height (int): Number of patches in the height dimension. + num_patches_width (int): Number of patches in the width dimension. + + Returns: + torch.Tensor: Rotary positional embeddings for the given grid size. + """ + # Generate position IDs for height and width dimensions + hpos_ids = ( + torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width) + ) + wpos_ids = ( + torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1) + ) + + # Flatten and stack the position IDs + pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1) + + # Generate the full rotary positional embeddings for the maximum grid size + max_grid_size = max(num_patches_height, num_patches_width) + seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + rotary_pos_emb_full = torch.outer(seq, self.inv_freq) + + # Select and flatten the embeddings based on the position IDs + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + + return rotary_pos_emb class MLCDVisionEmbeddings(CLIPVisionEmbeddings): @@ -258,8 +289,7 @@ def forward( @auto_docstring class MLCDPreTrainedModel(PreTrainedModel): config: MLCDVisionConfig - base_model_prefix = "vision_model" - _no_split_modules = ["MLCDEncoderLayer"] + base_model_prefix = "mlcd" supports_gradient_checkpointing = True accepts_loss_kwargs = False _supports_flash_attn = True @@ -294,7 +324,7 @@ def _init_weights(self, module): fc_std = (2 * module.config.hidden_size) ** -0.5 * factor init.normal_(module.fc1.weight, std=fc_std) init.normal_(module.fc2.weight, std=in_proj_std) - elif isinstance(module, MLCDVisionModel): + elif isinstance(module, MLCDVisionTransformer): factor = self.config.initializer_factor pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std) @@ -308,7 +338,7 @@ def _init_weights(self, module): init.copy_(module.inv_freq, inv_freq) -class MLCDVisionModel(CLIPVisionModel): +class MLCDVisionTransformer(CLIPVisionTransformer): def __init__(self, config: MLCDVisionConfig): super().__init__(config) self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2) @@ 
-319,43 +349,13 @@ def forward( pixel_values: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - r""" - Example: - - ```python - >>> import httpx - >>> from io import BytesIO - >>> from PIL import Image - >>> from transformers import AutoProcessor, MLCDVisionModel - >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448") - >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448") - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> with httpx.stream("GET", url) as response: - ... image = Image.open(BytesIO(response.read())) - >>> inputs = processor(images=image, return_tensors="pt") - - >>> with torch.no_grad(): - ... outputs = model(**inputs, output_attentions=True) - - >>> features = outputs.last_hidden_state - >>> print(f"Extracted features shape: {features.shape}") - >>> print(f"Number of attention layers: {len(outputs.attentions)}") - >>> print(f"Attention shape: {outputs.attentions[0].shape}") - ```""" if pixel_values is None: raise ValueError("You have to specify pixel_values") num_patches_height = pixel_values.shape[-2] // self.config.patch_size num_patches_width = pixel_values.shape[-1] // self.config.patch_size - hpos_ids = ( - torch.arange(num_patches_height, device=pixel_values.device).unsqueeze(1).expand(-1, num_patches_width) - ) - wpos_ids = ( - torch.arange(num_patches_width, device=pixel_values.device).unsqueeze(0).expand(num_patches_height, -1) - ) - pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1) - rotary_pos_emb = self.vision_rotary_embedding(pos_ids) + rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width) + rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device) rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -379,6 +379,42 @@ def forward( ) +class MLCDVisionModel(CLIPVisionModel): + def forward( + self, + pixel_values: torch.FloatTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + Example: + + ```python + >>> import httpx + >>> from io import BytesIO + >>> from PIL import Image + >>> from transformers import AutoProcessor, MLCDVisionModel + >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448") + >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> with httpx.stream("GET", url) as response: + ... image = Image.open(BytesIO(response.read())) + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... 
outputs = model(**inputs, output_attentions=True) + + >>> features = outputs.last_hidden_state + >>> print(f"Extracted features shape: {features.shape}") + >>> print(f"Number of attention layers: {len(outputs.attentions)}") + >>> print(f"Attention shape: {outputs.attentions[0].shape}") + ```""" + return self.vision_model( + pixel_values=pixel_values, + **kwargs, + ) + + __all__ = [ "MLCDVisionConfig", "MLCDPreTrainedModel", From e62aa98471d88b30d98cdfcb5021fe036ec8d99a Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 10:58:09 +0200 Subject: [PATCH 04/56] fix --- src/transformers/models/glm46v/modeling_glm46v.py | 12 ++++++------ .../paddleocr_vl/configuration_paddleocr_vl.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index 11e4849405c9..4d20344dd4b8 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -292,12 +292,12 @@ def get_video_features( """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames - temp_frames_hw = [] - video_grid_thw_list = video_grid_thw.tolist() - for t, h, w in video_grid_thw_list: - repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) - temp_frames_hw.append(repeated_row) - flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + t = video_grid_thw[:, 0] + hw = video_grid_thw[:, 1:] + # repeat each (h,w) row `t` times + flattened_hw = torch.repeat_interleave(hw, t, dim=0) + prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) + flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs ) diff --git a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py index 527f0fee9ea1..ebdb6897b6c4 100644 --- a/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/configuration_paddleocr_vl.py @@ -167,7 +167,7 @@ class PaddleOCRVLConfig(PreTrainedConfig): video_token_id: int = 100296 vision_start_token_id: int = 101305 vision_end_token_id: int = 101306 - tie_word_embeddings: int = True + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.vision_config, dict): From c1d7a8a515452739e33de29259b346bdd5d3dbfb Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 11:20:39 +0200 Subject: [PATCH 05/56] kwargs --- .../ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py | 8 ++++++-- .../ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py | 8 ++++++-- src/transformers/models/glm4v/modeling_glm4v.py | 6 ++++-- src/transformers/models/glm4v/modular_glm4v.py | 3 ++- .../models/glm4v_moe/modeling_glm4v_moe.py | 6 ++++-- .../models/glm_image/modeling_glm_image.py | 3 ++- .../models/glm_image/modular_glm_image.py | 3 ++- src/transformers/models/glm_ocr/modeling_glm_ocr.py | 6 ++++-- .../models/qwen2_5_omni/modeling_qwen2_5_omni.py | 12 ++++++------ .../models/qwen2_5_omni/modular_qwen2_5_omni.py | 12 ++++++------ .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 6 ++++-- .../models/qwen2_vl/modeling_qwen2_vl.py | 6 ++++-- .../models/qwen2_vl/processing_qwen2_vl.py | 11 +++++++++++ src/transformers/models/qwen3_5/modeling_qwen3_5.py | 3 ++- src/transformers/models/qwen3_5/modular_qwen3_5.py | 
3 ++- .../models/qwen3_5_moe/modeling_qwen3_5_moe.py | 3 ++- .../models/qwen3_omni_moe/modeling_qwen3_omni_moe.py | 8 ++++++-- .../models/qwen3_omni_moe/modular_qwen3_omni_moe.py | 12 ++++++------ .../models/qwen3_vl/modeling_qwen3_vl.py | 3 ++- src/transformers/models/qwen3_vl/modular_qwen3_vl.py | 3 ++- .../models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 3 ++- .../models/video_llama_3/modeling_video_llama_3.py | 2 ++ .../models/video_llama_3/modular_video_llama_3.py | 2 ++ src/transformers/utils/generic.py | 7 +++++++ 24 files changed, 96 insertions(+), 43 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index c5580476b78d..debd32626ffc 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -1251,7 +1251,10 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_outputs = self.vision_tower(pixel_values_videos, video_grid_thw, return_dict=True, **kwargs) + video_kwargs = kwargs.pop("video_kwargs", None) or {} + video_outputs = self.vision_tower( + pixel_values_videos, video_grid_thw, return_dict=True, **video_kwargs, **kwargs + ) video_embeds = self.resampler_model(video_outputs.last_hidden_state, video_grid_thw) split_sizes = ( video_grid_thw.prod(-1) @@ -1276,7 +1279,8 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - image_outputs = self.vision_tower(pixel_values, image_grid_thw, return_dict=True, **kwargs) + image_kwargs = kwargs.pop("image_kwargs", None) or {} + image_outputs = self.vision_tower(pixel_values, image_grid_thw, return_dict=True, **image_kwargs, **kwargs) image_embeds = self.resampler_model(image_outputs.last_hidden_state, image_grid_thw) split_sizes = (image_grid_thw.prod(-1) // self.vision_tower.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index 985518cfedc5..925eccd581bc 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -972,7 +972,10 @@ def get_video_features( video_grid_thw: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - video_outputs = self.vision_tower(pixel_values_videos, video_grid_thw, return_dict=True, **kwargs) + video_kwargs = kwargs.pop("video_kwargs", None) or {} + video_outputs = self.vision_tower( + pixel_values_videos, video_grid_thw, return_dict=True, **video_kwargs, **kwargs + ) video_embeds = self.resampler_model(video_outputs.last_hidden_state, video_grid_thw) split_sizes = ( video_grid_thw.prod(-1) @@ -991,7 +994,8 @@ def get_image_features( image_grid_thw: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - image_outputs = self.vision_tower(pixel_values, image_grid_thw, return_dict=True, **kwargs) + image_kwargs = kwargs.pop("image_kwargs", None) or {} + image_outputs = self.vision_tower(pixel_values, image_grid_thw, return_dict=True, **image_kwargs, **kwargs) 
image_embeds = self.resampler_model(image_outputs.last_hidden_state, image_grid_thw) split_sizes = (image_grid_thw.prod(-1) // self.vision_tower.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 57e3481605e2..7f0df1fd6fc9 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1109,6 +1109,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ + video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames t = video_grid_thw[:, 0] @@ -1118,7 +1119,7 @@ def get_video_features( prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( - pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs + pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **video_kwargs, **kwargs ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) @@ -1140,8 +1141,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 1b70144ccf64..b4d2038771eb 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -797,6 +797,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
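The per-frame grid flattening these GLM `get_video_features` hunks rely on can be checked in isolation; the sketch below (plain tensors, no model weights, hypothetical grid sizes) shows that the vectorized `repeat_interleave` form produces the same rows as the original Python loop.

```python
import torch

# Example grid: two videos with 4 and 2 frames, each frame 6 x 8 patches (hypothetical sizes).
video_grid_thw = torch.tensor([[4, 6, 8], [2, 6, 8]])

# Loop-based flattening: one [1, h, w] row per frame.
rows = []
for t, h, w in video_grid_thw.tolist():
    rows.append(torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1))
flattened_loop = torch.cat(rows, dim=0)

# Vectorized flattening: repeat each (h, w) row t times, then prepend a column of ones.
t = video_grid_thw[:, 0]
hw = video_grid_thw[:, 1:]
flattened_hw = torch.repeat_interleave(hw, t, dim=0)
prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1)
flattened_vec = torch.cat([prefix_ones, flattened_hw], dim=1)

assert torch.equal(flattened_loop, flattened_vec)  # (6, 3): one [1, h, w] row per frame
```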
""" + video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames t = video_grid_thw[:, 0] @@ -806,7 +807,7 @@ def get_video_features( prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( - pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs + pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **video_kwargs, **kwargs ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index fa863b6e55ab..2e99db1ba348 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -1278,6 +1278,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ + video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames t = video_grid_thw[:, 0] @@ -1287,7 +1288,7 @@ def get_video_features( prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( - pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs + pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **video_kwargs, **kwargs ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) @@ -1309,8 +1310,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index b3d306e31a67..19d0c15d7695 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -1178,8 +1178,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index ba572d1485a8..fbe9a71985be 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -717,8 +717,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index 6515d5e21e2b..12eba3641b9f 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -1026,6 +1026,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ + video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames t = video_grid_thw[:, 0] @@ -1035,7 +1036,7 @@ def get_video_features( prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( - pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs + pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **video_kwargs, **kwargs ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) @@ -1057,8 +1058,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 2a517b346130..aa4472152b25 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1718,8 +1718,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ + video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) @can_return_tuple @auto_docstring @@ -1735,8 +1736,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) @can_return_tuple @auto_docstring @@ -1765,11 +1767,9 @@ def get_audio_features( audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) ) feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + audio_kwargs = kwargs.pop("audio_kwargs", None) or {} audio_outputs = self.audio_tower( - input_features, - feature_lens=feature_lens, - return_dict=True, - **kwargs, + input_features, feature_lens=feature_lens, return_dict=True, **audio_kwargs, **kwargs ) if audio_outputs.last_hidden_state.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index dec2b77cf06e..becfd5f70f44 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1728,8 +1728,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ + video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) @can_return_tuple @auto_docstring @@ -1745,8 +1746,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) @can_return_tuple @auto_docstring @@ -1775,11 +1777,9 @@ def get_audio_features( audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) ) feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + audio_kwargs = kwargs.pop("audio_kwargs", None) or {} audio_outputs = self.audio_tower( - input_features, - feature_lens=feature_lens, - return_dict=True, - **kwargs, + input_features, feature_lens=feature_lens, return_dict=True, **audio_kwargs, **kwargs ) if audio_outputs.last_hidden_state.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index ecdde65ae4b8..64e0f7e1d961 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1089,8 +1089,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ + video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds @@ -1111,8 +1112,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 576ff9b206ff..d2c4317cf95e 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1078,8 +1078,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
""" + video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds @@ -1100,8 +1101,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py index 9c38451e60e8..83a41c20978c 100644 --- a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py @@ -22,6 +22,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging @@ -88,10 +89,20 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] + spatial_merge_size = self.image_processor.merge_size + image_inputs["image_kwargs"] = { + "cu_seqlens": get_vision_cu_seqlens(image_grid_thw), + "rotary_pos_ids": get_rotary_pos_ids(image_grid_thw, spatial_merge_size), + } if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] + spatial_merge_size = self.video_processor.merge_size + videos_inputs["video_kwargs"] = { + "cu_seqlens": get_vision_cu_seqlens(video_grid_thw), + "rotary_pos_ids": get_rotary_pos_ids(video_grid_thw, spatial_merge_size), + } if not isinstance(text, list): text = [text] diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index f480e8e50c7c..4c6dccc40d4d 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1421,9 +1421,10 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithPooling = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 1e075af3404a..9262a3fb4f0d 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -582,9 +582,10 @@ def get_image_features( image_grid_thw: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithPooling = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index cada7aaf0b3a..e551ea9131df 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1546,9 +1546,10 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithPooling = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 5b09794e08ee..3f36088678c1 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1877,8 +1877,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ + video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) @can_return_tuple @auto_docstring @@ -1894,8 +1895,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) @can_return_tuple @auto_docstring @@ -1921,10 +1923,12 @@ def get_audio_features( audio_feature_lengths = None feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + audio_kwargs = kwargs.pop("audio_kwargs", None) or {} audio_outputs = self.audio_tower( input_features, feature_lens=feature_lens, return_dict=True, + **audio_kwargs, **kwargs, ) diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 55a0621bc68d..920e1a51703c 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1210,8 +1210,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ + video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) @can_return_tuple @auto_docstring @@ -1227,8 +1228,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) @can_return_tuple @auto_docstring @@ -1254,11 +1256,9 @@ def get_audio_features( audio_feature_lengths = None feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + audio_kwargs = kwargs.pop("audio_kwargs", None) or {} audio_outputs = self.audio_tower( - input_features, - feature_lens=feature_lens, - return_dict=True, - **kwargs, + input_features, feature_lens=feature_lens, return_dict=True, **audio_kwargs, **kwargs ) return audio_outputs diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 7808d7e520b1..ec70fe060ea0 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -1063,9 +1063,10 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index df69111d99d7..2ffc06556163 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -698,9 +698,10 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index ae80e02edd8a..6d54482d5599 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -1193,9 +1193,10 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ + image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 338ddfdeeb9c..53aa2ff32b73 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -561,11 +561,13 @@ def get_image_features( image_merge_sizes (`torch.Tensor` of shape `(num_images,)`): The spatial downsampling ratio of each image feature. 
""" + image_kwargs = kwargs.pop("image_kwargs", None) or {} vision_outputs = self.vision_model( pixel_values=pixel_values, grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, return_dict=True, + **image_kwargs, **kwargs, ) last_hidden_state = vision_outputs.last_hidden_state diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 1aad0f3daefa..6eab6fc9cff7 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -522,11 +522,13 @@ def get_image_features( image_merge_sizes (`torch.Tensor` of shape `(num_images,)`): The spatial downsampling ratio of each image feature. """ + image_kwargs = kwargs.pop("image_kwargs", None) or {} vision_outputs = self.vision_model( pixel_values=pixel_values, grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, return_dict=True, + **image_kwargs, **kwargs, ) last_hidden_state = vision_outputs.last_hidden_state diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 5228ede6dd76..646066319437 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -792,6 +792,10 @@ class TransformersKwargs(TypedDict, total=False): Indices of positions of each input sequence tokens. is_causal (`bool`, *optional*) Can be set to False to enable bi-directional attention, i.e. use decoder Attention modules as encoders. + image_kwargs (`dict`, *optional*): + Precomputed vision tensors for images (from the processor), passed to the vision encoder. + video_kwargs (`dict`, *optional*): + Precomputed vision tensors for videos (from the processor), passed to the vision encoder. """ num_items_in_batch: torch.Tensor | None @@ -804,6 +808,9 @@ class TransformersKwargs(TypedDict, total=False): max_length_k: int | None position_ids: torch.LongTensor | None is_causal: bool | None + image_kwargs: dict | None + video_kwargs: dict | None + audio_kwargs: dict | None def is_timm_config_dict(config_dict: dict[str, Any]) -> bool: From 2771799642590bb1c57f5c4f46840d73d6251922 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 11:42:38 +0200 Subject: [PATCH 06/56] opt-in --- .../modeling_ernie4_5_vl_moe.py | 18 ++++++++++--- .../modular_ernie4_5_vl_moe.py | 18 ++++++++++--- .../models/glm46v/modeling_glm46v.py | 15 +++++++++-- .../models/glm46v/processing_glm46v.py | 10 +++++++ .../models/glm4v/modeling_glm4v.py | 17 +++++++++--- .../models/glm4v/modular_glm4v.py | 17 ++++++++++-- .../models/glm4v/processing_glm4v.py | 10 +++++++ .../models/glm4v_moe/modeling_glm4v_moe.py | 17 +++++++++--- .../models/glm_image/modeling_glm_image.py | 10 +++++-- .../models/glm_image/modular_glm_image.py | 10 +++++-- .../models/glm_ocr/modeling_glm_ocr.py | 17 +++++++++--- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 23 ++++++++++------ .../qwen2_5_omni/modular_qwen2_5_omni.py | 23 ++++++++++------ .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 18 ++++++++++--- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 10 +++++++ .../qwen2_5_vl/processing_qwen2_5_vl.py | 11 ++++++++ .../models/qwen2_vl/modeling_qwen2_vl.py | 18 ++++++++++--- .../models/qwen2_vl/processing_qwen2_vl.py | 20 +++++++------- .../models/qwen3_5/modeling_qwen3_5.py | 8 ++++-- .../models/qwen3_5/modular_qwen3_5.py | 8 ++++-- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 8 ++++-- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 27 ++++++++++--------- .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 
23 ++++++++++------ .../models/qwen3_vl/modeling_qwen3_vl.py | 8 ++++-- .../models/qwen3_vl/modular_qwen3_vl.py | 17 ++++++++++-- .../models/qwen3_vl/processing_qwen3_vl.py | 11 ++++++++ .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 8 ++++-- .../video_llama_3/modeling_video_llama_3.py | 4 +-- .../video_llama_3/modular_video_llama_3.py | 9 +++++-- .../video_llama_3/processing_video_llama_3.py | 6 +++++ src/transformers/utils/generic.py | 7 ----- 31 files changed, 321 insertions(+), 105 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index debd32626ffc..bc0c25339416 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -1251,9 +1251,13 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_kwargs = kwargs.pop("video_kwargs", None) or {} video_outputs = self.vision_tower( - pixel_values_videos, video_grid_thw, return_dict=True, **video_kwargs, **kwargs + pixel_values_videos, + video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) video_embeds = self.resampler_model(video_outputs.last_hidden_state, video_grid_thw) split_sizes = ( @@ -1279,8 +1283,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - image_kwargs = kwargs.pop("image_kwargs", None) or {} - image_outputs = self.vision_tower(pixel_values, image_grid_thw, return_dict=True, **image_kwargs, **kwargs) + image_outputs = self.vision_tower( + pixel_values, + image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + return_dict=True, + **kwargs, + ) image_embeds = self.resampler_model(image_outputs.last_hidden_state, image_grid_thw) split_sizes = (image_grid_thw.prod(-1) // self.vision_tower.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index 925eccd581bc..5c2269efa7ac 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -972,9 +972,13 @@ def get_video_features( video_grid_thw: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - video_kwargs = kwargs.pop("video_kwargs", None) or {} video_outputs = self.vision_tower( - pixel_values_videos, video_grid_thw, return_dict=True, **video_kwargs, **kwargs + pixel_values_videos, + video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) video_embeds = self.resampler_model(video_outputs.last_hidden_state, video_grid_thw) split_sizes = ( @@ -994,8 +998,14 @@ def get_image_features( image_grid_thw: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - image_kwargs = kwargs.pop("image_kwargs", None) or {} - image_outputs = self.vision_tower(pixel_values, 
image_grid_thw, return_dict=True, **image_kwargs, **kwargs) + image_outputs = self.vision_tower( + pixel_values, + image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + return_dict=True, + **kwargs, + ) image_embeds = self.resampler_model(image_outputs.last_hidden_state, image_grid_thw) split_sizes = (image_grid_thw.prod(-1) // self.vision_tower.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index 4d20344dd4b8..3d65679d6a2a 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -299,7 +299,12 @@ def get_video_features( prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( - pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs + pixel_values_videos, + grid_thw=flattened_video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) @@ -322,7 +327,13 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + **kwargs, + ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/glm46v/processing_glm46v.py b/src/transformers/models/glm46v/processing_glm46v.py index 9dcf7c4856e6..7687d55478b0 100644 --- a/src/transformers/models/glm46v/processing_glm46v.py +++ b/src/transformers/models/glm46v/processing_glm46v.py @@ -23,6 +23,7 @@ from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging @@ -83,6 +84,7 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" + return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Glm46VProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -91,6 +93,10 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.image_processor.merge_size + image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) + image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) else: image_inputs = {} image_grid_thw = None @@ -103,6 +109,10 @@ def __call__( else: video_metadata = videos_inputs["video_metadata"] video_grid_thw = videos_inputs["video_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.video_processor.merge_size + videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) + videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) else: videos_inputs = {} video_grid_thw = None diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 7f0df1fd6fc9..0af53440dba1 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1109,7 +1109,6 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames t = video_grid_thw[:, 0] @@ -1119,7 +1118,12 @@ def get_video_features( prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( - pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **video_kwargs, **kwargs + pixel_values_videos, + grid_thw=flattened_video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) @@ -1141,9 +1145,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) + vision_outputs = self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + **kwargs, + ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index b4d2038771eb..88de6ee1828b 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -797,7 +797,6 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames t = video_grid_thw[:, 0] @@ -807,7 +806,12 @@ def get_video_features( prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( - pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **video_kwargs, **kwargs + pixel_values_videos, + grid_thw=flattened_video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) @@ -1273,6 +1277,7 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" + return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Glm4vProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -1281,6 +1286,10 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.image_processor.merge_size + image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) + image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) else: image_inputs = {} image_grid_thw = None @@ -1293,6 +1302,10 @@ def __call__( else: video_metadata = videos_inputs["video_metadata"] video_grid_thw = videos_inputs["video_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.video_processor.merge_size + videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) + videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) else: videos_inputs = {} video_grid_thw = None diff --git a/src/transformers/models/glm4v/processing_glm4v.py b/src/transformers/models/glm4v/processing_glm4v.py index 2d3e93aec9ed..eaf8a5d90e1b 100644 --- a/src/transformers/models/glm4v/processing_glm4v.py +++ b/src/transformers/models/glm4v/processing_glm4v.py @@ -22,6 +22,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging @@ -82,6 +83,7 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" + return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Glm4vProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -90,6 +92,10 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.image_processor.merge_size + image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) + image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) else: image_inputs = {} image_grid_thw = None @@ -102,6 +108,10 @@ def __call__( else: video_metadata = videos_inputs["video_metadata"] video_grid_thw = videos_inputs["video_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.video_processor.merge_size + videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) + videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) else: videos_inputs = {} video_grid_thw = None diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 2e99db1ba348..d53fe3922ae5 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -1278,7 +1278,6 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames t = video_grid_thw[:, 0] @@ -1288,7 +1287,12 @@ def get_video_features( prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( - pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **video_kwargs, **kwargs + pixel_values_videos, + grid_thw=flattened_video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) @@ -1310,9 +1314,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) + vision_outputs = self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + **kwargs, + ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 19d0c15d7695..63da785d703f 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -1178,9 +1178,15 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs) + vision_outputs = self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + return_dict=True, + **kwargs, + ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index fbe9a71985be..ee5a4cbc23d6 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -717,9 +717,15 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs) + vision_outputs = self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + return_dict=True, + **kwargs, + ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index 12eba3641b9f..3a8d6dd6634e 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -1026,7 +1026,6 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
""" - video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames t = video_grid_thw[:, 0] @@ -1036,7 +1035,12 @@ def get_video_features( prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( - pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **video_kwargs, **kwargs + pixel_values_videos, + grid_thw=flattened_video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) @@ -1058,9 +1062,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) + vision_outputs = self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + **kwargs, + ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index aa4472152b25..8275fe0324cb 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1718,9 +1718,14 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) + return self.visual( + pixel_values_videos, + grid_thw=video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1736,9 +1741,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) + return self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1767,10 +1777,7 @@ def get_audio_features( audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) ) feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_kwargs = kwargs.pop("audio_kwargs", None) or {} - audio_outputs = self.audio_tower( - input_features, feature_lens=feature_lens, return_dict=True, **audio_kwargs, **kwargs - ) + audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) if audio_outputs.last_hidden_state.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index becfd5f70f44..44f7ffa33ef2 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1728,9 +1728,14 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) + return self.visual( + pixel_values_videos, + grid_thw=video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1746,9 +1751,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) + return self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1777,10 +1787,7 @@ def get_audio_features( audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) ) feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_kwargs = kwargs.pop("audio_kwargs", None) or {} - audio_outputs = self.audio_tower( - input_features, feature_lens=feature_lens, return_dict=True, **audio_kwargs, **kwargs - ) + audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) if audio_outputs.last_hidden_state.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 64e0f7e1d961..d98dd94bd770 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1089,9 +1089,14 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) + vision_outputs = self.visual( + pixel_values_videos, + grid_thw=video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + **kwargs, + ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds @@ -1112,9 +1117,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) + vision_outputs = self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + **kwargs, + ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 3d489cd6f66e..3454a517befd 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -762,14 +762,24 @@ def __call__( **kwargs, ) + return_extra_tensors = kwargs.pop("return_extra_tensors", False) + image_inputs = videos_inputs = {} if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.image_processor.merge_size + image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) + image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.video_processor.merge_size + videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) + videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) # Get video metadata if not kwargs.get("return_metadata"): diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index 8873eb82557a..4cd6af8e75fd 100644 --- a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -24,6 +24,7 @@ # limitations under the License. 
from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring @@ -85,14 +86,24 @@ def __call__( **kwargs, ) + return_extra_tensors = kwargs.pop("return_extra_tensors", False) + image_inputs = videos_inputs = {} if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.image_processor.merge_size + image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) + image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.video_processor.merge_size + videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) + videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) # Get video metadata if not kwargs.get("return_metadata"): diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index d2c4317cf95e..67ff370e5fd5 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1078,9 +1078,14 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) + vision_outputs = self.visual( + pixel_values_videos, + grid_thw=video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + **kwargs, + ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds @@ -1101,9 +1106,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) + vision_outputs = self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + **kwargs, + ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py index 83a41c20978c..768af8c352cb 100644 --- a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py @@ -85,24 +85,24 @@ def __call__( **kwargs, ) + return_extra_tensors = kwargs.pop("return_extra_tensors", False) + image_inputs = videos_inputs = {} if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] - spatial_merge_size = self.image_processor.merge_size - image_inputs["image_kwargs"] = { - "cu_seqlens": get_vision_cu_seqlens(image_grid_thw), - "rotary_pos_ids": get_rotary_pos_ids(image_grid_thw, spatial_merge_size), - } + if return_extra_tensors: + spatial_merge_size = self.image_processor.merge_size + image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) + image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] - spatial_merge_size = self.video_processor.merge_size - videos_inputs["video_kwargs"] = { - "cu_seqlens": get_vision_cu_seqlens(video_grid_thw), - "rotary_pos_ids": get_rotary_pos_ids(video_grid_thw, spatial_merge_size), - } + if return_extra_tensors: + spatial_merge_size = self.video_processor.merge_size + videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) + videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) if not isinstance(text, list): text = [text] diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 4c6dccc40d4d..b2c2fc6550b3 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1421,10 +1421,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithPooling = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 9262a3fb4f0d..8977037115b5 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -582,10 +582,14 @@ def get_image_features( image_grid_thw: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithPooling = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index e551ea9131df..512eef74c3bc 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1546,10 +1546,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithPooling = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 3f36088678c1..d1642d29b5b4 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1877,9 +1877,14 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
""" - video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) + return self.visual( + pixel_values_videos, + grid_thw=video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1895,9 +1900,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) + return self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1923,14 +1933,7 @@ def get_audio_features( audio_feature_lengths = None feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_kwargs = kwargs.pop("audio_kwargs", None) or {} - audio_outputs = self.audio_tower( - input_features, - feature_lens=feature_lens, - return_dict=True, - **audio_kwargs, - **kwargs, - ) + audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) return audio_outputs diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 920e1a51703c..7d0acfd8bd12 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1210,9 +1210,14 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_kwargs = kwargs.pop("video_kwargs", None) or {} pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **video_kwargs, **kwargs) + return self.visual( + pixel_values_videos, + grid_thw=video_grid_thw, + cu_seqlens=kwargs.pop("video_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1228,9 +1233,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **image_kwargs, **kwargs) + return self.visual( + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1256,10 +1266,7 @@ def get_audio_features( audio_feature_lengths = None feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_kwargs = kwargs.pop("audio_kwargs", None) or {} - audio_outputs = self.audio_tower( - input_features, feature_lens=feature_lens, return_dict=True, **audio_kwargs, **kwargs - ) + audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) return audio_outputs diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index ec70fe060ea0..dc0e29138669 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -1063,10 +1063,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 2ffc06556163..f3d4e21a26fd 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -698,10 +698,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1132,6 +1136,7 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" + return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Qwen3VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -1140,6 +1145,10 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.image_processor.merge_size + image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) + image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) else: image_inputs = {} image_grid_thw = None @@ -1147,6 +1156,10 @@ def __call__( if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.video_processor.merge_size + videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) + videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) # If user has not requested video metadata, pop it if not kwargs.get("return_metadata"): video_metadata = videos_inputs.pop("video_metadata") diff --git a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py index 1ca435749ad2..c9098d9436b7 100644 --- a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py @@ -22,6 +22,8 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput +from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids +from ...modeling_vision_utils import get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging @@ -96,6 +98,7 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" + return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Qwen3VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -104,6 +107,10 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.image_processor.merge_size + image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) + image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) else: image_inputs = {} image_grid_thw = None @@ -111,6 +118,10 @@ def __call__( if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] + if return_extra_tensors: + spatial_merge_size = self.video_processor.merge_size + videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) + videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) # If user has not requested video metadata, pop it if not kwargs.get("return_metadata"): video_metadata = videos_inputs.pop("video_metadata") diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 6d54482d5599..ae5dfba0c5cf 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -1193,10 +1193,14 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - image_kwargs = kwargs.pop("image_kwargs", None) or {} pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( - pixel_values, grid_thw=image_grid_thw, return_dict=True, **image_kwargs, **kwargs + pixel_values, + grid_thw=image_grid_thw, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + return_dict=True, + **kwargs, ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 53aa2ff32b73..cfc594a84f17 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -561,13 +561,13 @@ def get_image_features( image_merge_sizes (`torch.Tensor` of shape `(num_images,)`): The spatial downsampling ratio of each image feature. 
""" - image_kwargs = kwargs.pop("image_kwargs", None) or {} vision_outputs = self.vision_model( pixel_values=pixel_values, grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), return_dict=True, - **image_kwargs, **kwargs, ) last_hidden_state = vision_outputs.last_hidden_state diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 6eab6fc9cff7..7506c711a03c 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -522,13 +522,13 @@ def get_image_features( image_merge_sizes (`torch.Tensor` of shape `(num_images,)`): The spatial downsampling ratio of each image feature. """ - image_kwargs = kwargs.pop("image_kwargs", None) or {} vision_outputs = self.vision_model( pixel_values=pixel_values, grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, + cu_seqlens=kwargs.pop("image_cu_seqlens", None), + rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), return_dict=True, - **image_kwargs, **kwargs, ) last_hidden_state = vision_outputs.last_hidden_state @@ -1007,11 +1007,16 @@ def __call__( **kwargs, ) + return_extra_tensors = kwargs.pop("return_extra_tensors", False) + image_inputs = videos_inputs = {} if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] image_merge_sizes = image_inputs["image_merge_sizes"] + if return_extra_tensors: + image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) + image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, image_merge_sizes) else: image_grid_thw = image_merge_sizes = [] diff --git a/src/transformers/models/video_llama_3/processing_video_llama_3.py b/src/transformers/models/video_llama_3/processing_video_llama_3.py index 7916d7e41d8e..321a5b46f769 100644 --- a/src/transformers/models/video_llama_3/processing_video_llama_3.py +++ b/src/transformers/models/video_llama_3/processing_video_llama_3.py @@ -19,6 +19,7 @@ # limitations under the License. from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging @@ -82,11 +83,16 @@ def __call__( **kwargs, ) + return_extra_tensors = kwargs.pop("return_extra_tensors", False) + image_inputs = videos_inputs = {} if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] image_merge_sizes = image_inputs["image_merge_sizes"] + if return_extra_tensors: + image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) + image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, image_merge_sizes) else: image_grid_thw = image_merge_sizes = [] diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 646066319437..5228ede6dd76 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -792,10 +792,6 @@ class TransformersKwargs(TypedDict, total=False): Indices of positions of each input sequence tokens. 
is_causal (`bool`, *optional*) Can be set to False to enable bi-directional attention, i.e. use decoder Attention modules as encoders. - image_kwargs (`dict`, *optional*): - Precomputed vision tensors for images (from the processor), passed to the vision encoder. - video_kwargs (`dict`, *optional*): - Precomputed vision tensors for videos (from the processor), passed to the vision encoder. """ num_items_in_batch: torch.Tensor | None @@ -808,9 +804,6 @@ class TransformersKwargs(TypedDict, total=False): max_length_k: int | None position_ids: torch.LongTensor | None is_causal: bool | None - image_kwargs: dict | None - video_kwargs: dict | None - audio_kwargs: dict | None def is_timm_config_dict(config_dict: dict[str, Any]) -> bool: From fa224e20de91deb559305e7d680dc060c6342d54 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 12:05:00 +0200 Subject: [PATCH 07/56] fix dtype --- src/transformers/models/qwen3_5/modeling_qwen3_5.py | 2 +- src/transformers/models/qwen3_5/modular_qwen3_5.py | 2 +- src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py | 2 +- .../models/qwen3_omni_moe/modeling_qwen3_omni_moe.py | 2 +- src/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 2 +- src/transformers/models/qwen3_vl/modular_qwen3_vl.py | 2 +- src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index b2c2fc6550b3..d2bb3fb7ccd1 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1064,7 +1064,7 @@ def forward( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) - hidden_states = hidden_states + pos_embeds + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 8977037115b5..57c21714c8c8 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -457,7 +457,7 @@ def forward( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) - hidden_states = hidden_states + pos_embeds + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 512eef74c3bc..db1fc3d961e6 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1157,7 +1157,7 @@ def forward( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) - hidden_states = hidden_states + pos_embeds + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py 
b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index d1642d29b5b4..2f87ed0f0174 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1154,7 +1154,7 @@ def forward( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) - hidden_states = hidden_states + pos_embeds + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index dc0e29138669..1251c49462b8 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -695,7 +695,7 @@ def forward( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) - hidden_states = hidden_states + pos_embeds + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index f3d4e21a26fd..eda65088474b 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -483,7 +483,7 @@ def forward( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) - hidden_states = hidden_states + pos_embeds + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index ae5dfba0c5cf..008b48ae78b0 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -680,7 +680,7 @@ def forward( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) - hidden_states = hidden_states + pos_embeds + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) From ac2895d1a3f756042f9c2fb911762acb9a053e94 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 12:13:33 +0200 Subject: [PATCH 08/56] style --- docs/source/en/model_doc/nomic_bert.md | 2 +- src/transformers/models/esm/configuration_esm.py | 4 ++-- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 3 +-- src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py | 3 +-- src/transformers/models/qwen2_vl/processing_qwen2_vl.py | 3 +-- .../models/video_llama_3/modular_video_llama_3.py | 3 +-- .../models/video_llama_3/processing_video_llama_3.py | 3 +-- 7 files changed, 8 insertions(+), 13 deletions(-) diff --git a/docs/source/en/model_doc/nomic_bert.md b/docs/source/en/model_doc/nomic_bert.md index 73b3adc8a35f..2017805fe42a 100644 --- a/docs/source/en/model_doc/nomic_bert.md +++ 
b/docs/source/en/model_doc/nomic_bert.md @@ -23,7 +23,7 @@ limitations under the License. ## Overview -NomicBERT was proposed in [Nomic Embed: Training a Reproducible Long Context Text Embedder](https://arxiv.org/abs/2402.01613) by +NomicBERT was proposed in [Nomic Embed: Training a Reproducible Long Context Text Embedder](https://huggingface.co/papers/2402.01613) by Zach Nussbaum, John X. Morris, Brandon Duderstadt, and Andriy Mulyar. It is BERT-inspired with the most notable extension applying [Rotary Position Embeddings](https://huggingface.co/papers/2104.09864.pdf) to an encoder model. diff --git a/src/transformers/models/esm/configuration_esm.py b/src/transformers/models/esm/configuration_esm.py index a00dcf8b39e3..7875d88ecee8 100644 --- a/src/transformers/models/esm/configuration_esm.py +++ b/src/transformers/models/esm/configuration_esm.py @@ -159,12 +159,12 @@ class EsmConfig(PreTrainedConfig): mask_token_id (`int`, *optional*): The index of the mask token in the vocabulary. This must be included in the config because of the "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens. + rope_theta (`float`, defaults to 10000.0): + The base period of the RoPE embeddings. Only used when `position_embedding_type` is set to `"rotary"`. position_embedding_type (`str`, *optional*, defaults to `"absolute"`): Type of position embedding. Choose either `"absolute"` or "rotary"`. emb_layer_norm_before (`bool`, *optional*): Whether to apply layer normalization after embeddings but before the main stem of the network. - rope_theta (`float`, defaults to 10000.0): - The base period of the RoPE embeddings. Only used when `position_embedding_type` is set to `"rotary"`. token_dropout (`bool`, defaults to `False`): When this is enabled, masked tokens are treated as if they had been dropped out by input dropout. is_folding_model (`bool`, defaults to `False`): diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 3454a517befd..82d7ab4c1ee3 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -756,14 +756,13 @@ def __call__( - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. """ + return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Qwen2_5_VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) - return_extra_tensors = kwargs.pop("return_extra_tensors", False) - image_inputs = videos_inputs = {} if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index 4cd6af8e75fd..f6785781be2c 100644 --- a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -80,14 +80,13 @@ def __call__( - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. 
""" + return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Qwen2_5_VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) - return_extra_tensors = kwargs.pop("return_extra_tensors", False) - image_inputs = videos_inputs = {} if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) diff --git a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py index 768af8c352cb..2b8aa340e1f1 100644 --- a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py @@ -79,14 +79,13 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. """ + return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Qwen2VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) - return_extra_tensors = kwargs.pop("return_extra_tensors", False) - image_inputs = videos_inputs = {} if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 7506c711a03c..a22c5321c048 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -1001,14 +1001,13 @@ def __call__( videos: VideoInput = None, **kwargs: Unpack[VideoLlama3ProcessorKwargs], ) -> BatchFeature: + return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( VideoLlama3ProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) - return_extra_tensors = kwargs.pop("return_extra_tensors", False) - image_inputs = videos_inputs = {} if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) diff --git a/src/transformers/models/video_llama_3/processing_video_llama_3.py b/src/transformers/models/video_llama_3/processing_video_llama_3.py index 321a5b46f769..987cc0069faf 100644 --- a/src/transformers/models/video_llama_3/processing_video_llama_3.py +++ b/src/transformers/models/video_llama_3/processing_video_llama_3.py @@ -77,14 +77,13 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" + return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( VideoLlama3ProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) - return_extra_tensors = kwargs.pop("return_extra_tensors", False) - image_inputs = videos_inputs = {} if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) From 2f2787c9c395ba47bae506fa1e8e272aacc3d6c9 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 12:14:44 +0200 Subject: [PATCH 09/56] guard torch import --- src/transformers/modeling_vision_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_vision_utils.py b/src/transformers/modeling_vision_utils.py index 1fc7040f7553..84417033eb91 100644 --- a/src/transformers/modeling_vision_utils.py +++ b/src/transformers/modeling_vision_utils.py @@ -19,8 +19,14 @@ (``repeat_interleave``, ``.tolist()``, ``nonzero()``, loops). """ -import torch -import torch.nn.functional as F +from __future__ import annotations + +from .utils.import_utils import is_torch_available + + +if is_torch_available(): + import torch + import torch.nn.functional as F def get_vision_cu_seqlens(grid_thw: torch.Tensor) -> torch.Tensor: From d628d966aadfa42fc21869319ff36af6e2b44ad7 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 12:34:19 +0200 Subject: [PATCH 10/56] standarize --- .../modeling_ernie4_5_vl_moe.py | 38 ++++++++++-- .../modular_ernie4_5_vl_moe.py | 12 ++-- .../models/glm4v/modeling_glm4v.py | 34 +++++++++-- .../models/glm4v/modular_glm4v.py | 6 +- .../models/glm4v_moe/modeling_glm4v_moe.py | 34 +++++++++-- .../models/glm_image/modeling_glm_image.py | 6 +- .../models/glm_image/modular_glm_image.py | 6 +- .../models/glm_ocr/modeling_glm_ocr.py | 34 +++++++++-- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 12 ++-- .../qwen2_5_omni/modular_qwen2_5_omni.py | 12 ++-- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 56 +++++++++++++++--- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 12 +++- .../models/qwen2_vl/modeling_qwen2_vl.py | 58 ++++++++++++++++--- .../models/qwen3_5/modeling_qwen3_5.py | 34 +++++++++-- .../models/qwen3_5/modular_qwen3_5.py | 6 +- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 34 +++++++++-- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 12 ++-- .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 12 ++-- .../models/qwen3_vl/modeling_qwen3_vl.py | 50 +++++++++++++--- .../models/qwen3_vl/modular_qwen3_vl.py | 20 +++++-- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 50 +++++++++++++--- .../video_llama_3/modeling_video_llama_3.py | 28 +++++++-- .../video_llama_3/modular_video_llama_3.py | 10 +++- 23 files changed, 471 insertions(+), 105 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index bc0c25339416..16a7d2359e76 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -1243,6 +1243,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1250,12 +1252,16 @@ def get_video_features( The tensors corresponding to the input videos. 
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from the processor). + video_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs (from the processor). """ video_outputs = self.vision_tower( pixel_values_videos, video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -1275,6 +1281,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1282,12 +1290,16 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from the processor). + image_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs (from the processor). """ image_outputs = self.vision_tower( pixel_values, image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -1578,6 +1590,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1587,7 +1601,11 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -1595,6 +1613,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1603,7 +1623,13 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @auto_docstring @can_return_tuple diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index 5c2269efa7ac..b11d2426a652 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -970,13 +970,15 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: video_outputs = self.vision_tower( pixel_values_videos, video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -996,13 +998,15 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: image_outputs = self.vision_tower( pixel_values, image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 0af53440dba1..7602781dcf87 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1101,6 +1101,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1120,8 +1122,8 @@ def get_video_features( vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -1137,6 +1139,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1144,13 +1148,17 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from the processor). + image_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs (from the processor). 
""" pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1371,6 +1379,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1380,7 +1390,11 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -1388,6 +1402,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1396,7 +1412,13 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 88de6ee1828b..4cf37af4b4a0 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -789,6 +789,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -808,8 +810,8 @@ def get_video_features( vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index d53fe3922ae5..21a0f33994f2 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -1270,6 +1270,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1289,8 +1291,8 @@ def get_video_features( vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, - 
cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -1306,6 +1308,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1313,13 +1317,17 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from the processor). + image_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs (from the processor). """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1595,6 +1603,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1604,7 +1614,11 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -1612,6 +1626,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1620,7 +1636,13 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @auto_docstring @can_return_tuple diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 63da785d703f..fcbc4106d7e9 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -1170,6 +1170,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1182,8 +1184,8 @@ def get_image_features( vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index ee5a4cbc23d6..6ad5a61ea1c1 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -709,6 +709,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -721,8 +723,8 @@ def get_image_features( vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index 3a8d6dd6634e..aa75dec92165 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -1018,6 +1018,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1037,8 +1039,8 @@ def get_video_features( vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -1054,6 +1056,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1061,13 +1065,17 @@ def get_image_features( The tensors corresponding to the input images. 
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from the processor). + image_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs (from the processor). """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1288,6 +1296,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1297,7 +1307,11 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -1305,6 +1319,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1313,7 +1329,13 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 8275fe0324cb..d0e4c5027ae0 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1710,6 +1710,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1722,8 +1724,8 @@ def get_video_features( return self.visual( pixel_values_videos, grid_thw=video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, **kwargs, ) @@ -1733,6 +1735,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1745,8 +1749,8 @@ def get_image_features( return self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, **kwargs, ) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 44f7ffa33ef2..2d52008dfff4 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1720,6 +1720,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1732,8 +1734,8 @@ def get_video_features( return self.visual( pixel_values_videos, grid_thw=video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, **kwargs, ) @@ -1743,6 +1745,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1755,8 +1759,8 @@ def get_image_features( return self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, **kwargs, ) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index d98dd94bd770..fc623a11ba16 
100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1081,6 +1081,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1088,13 +1090,17 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from the processor). + video_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs (from the processor). """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) vision_outputs = self.visual( pixel_values_videos, grid_thw=video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, **kwargs, ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1109,6 +1115,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1116,13 +1124,17 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from the processor). + image_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs (from the processor). 
""" pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1237,6 +1249,10 @@ def forward( rope_deltas: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, second_per_grid_ts: torch.Tensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2_5_VLModelOutputWithPast: r""" @@ -1254,7 +1270,12 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw).pooler_output + image_embeds = self.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1262,7 +1283,12 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds @@ -1352,6 +1378,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1361,7 +1389,11 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -1369,6 +1401,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1377,7 +1411,13 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 82d7ab4c1ee3..754889619ddc 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -493,6 +493,10 @@ def forward( rope_deltas: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, second_per_grid_ts: torch.Tensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2_5_VLModelOutputWithPast: r""" @@ -510,7 +514,9 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, image_cu_seqlens, image_rotary_pos_ids, + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -518,7 +524,9 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, video_cu_seqlens, video_rotary_pos_ids, + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 67ff370e5fd5..b82a03eb07c7 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1070,6 +1070,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1077,13 +1079,17 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from the processor). + video_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs (from the processor). 
""" pixel_values_videos = pixel_values_videos.type(self.visual.dtype) vision_outputs = self.visual( pixel_values_videos, grid_thw=video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, **kwargs, ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1098,6 +1104,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1105,13 +1113,17 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from the processor). + image_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs (from the processor). """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1226,6 +1238,10 @@ def forward( video_grid_thw: torch.LongTensor | None = None, rope_deltas: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2VLModelOutputWithPast: r""" @@ -1235,13 +1251,23 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. + image_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths for images (from the processor). + image_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs for images (from the processor). + video_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from the processor). + video_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs for videos (from the processor). 
""" if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, image_cu_seqlens, image_rotary_pos_ids, + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1249,7 +1275,9 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, video_cu_seqlens, video_rotary_pos_ids, + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds @@ -1306,6 +1334,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1315,7 +1345,7 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, video_grid_thw, video_cu_seqlens, video_rotary_pos_ids, **kwargs, ) @auto_docstring @@ -1323,6 +1353,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1331,7 +1363,9 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, image_grid_thw, image_cu_seqlens, image_rotary_pos_ids, **kwargs, + ) @can_return_tuple @auto_docstring @@ -1350,6 +1384,10 @@ def forward( video_grid_thw: torch.LongTensor | None = None, rope_deltas: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2VLCausalLMOutputWithPast: @@ -1409,6 +1447,10 @@ def forward( image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, mm_token_type_ids=mm_token_type_ids, + image_cu_seqlens=image_cu_seqlens, + image_rotary_pos_ids=image_rotary_pos_ids, + video_cu_seqlens=video_cu_seqlens, + video_rotary_pos_ids=video_rotary_pos_ids, position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index d2bb3fb7ccd1..e324ab25186f 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1396,6 +1396,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1405,7 +1407,13 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ # Same implementation as for images - return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + return self.get_image_features( + pixel_values_videos, + video_grid_thw, + image_cu_seqlens=video_cu_seqlens, + image_rotary_pos_ids=video_rotary_pos_ids, + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1413,6 +1421,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1425,8 +1435,8 @@ def get_image_features( vision_output: BaseModelOutputWithPooling = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -1741,6 +1751,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1750,7 +1762,11 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. 
""" return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -1758,6 +1774,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1766,7 +1784,13 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @can_return_tuple def forward( diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 57c21714c8c8..99d17e3cc285 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -580,14 +580,16 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithPooling = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index db1fc3d961e6..e6d0244c880b 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1521,6 +1521,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1530,7 +1532,13 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. 
""" # Same implementation as for images - return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + return self.get_image_features( + pixel_values_videos, + video_grid_thw, + image_cu_seqlens=video_cu_seqlens, + image_rotary_pos_ids=video_rotary_pos_ids, + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1538,6 +1546,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1550,8 +1560,8 @@ def get_image_features( vision_output: BaseModelOutputWithPooling = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -1940,6 +1950,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1949,7 +1961,11 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -1957,6 +1973,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1965,7 +1983,13 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @can_return_tuple def forward( diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 2f87ed0f0174..cd0988daa0c0 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1869,6 +1869,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1881,8 +1883,8 @@ def get_video_features( return self.visual( pixel_values_videos, grid_thw=video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, **kwargs, ) @@ -1892,6 +1894,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1904,8 +1908,8 @@ def get_image_features( return self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, **kwargs, ) diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 7d0acfd8bd12..fa2cbdb4e77d 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1202,6 +1202,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1214,8 +1216,8 @@ def get_video_features( return self.visual( pixel_values_videos, grid_thw=video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, **kwargs, ) @@ -1225,6 +1227,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1237,8 +1241,8 @@ def get_image_features( return self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, **kwargs, ) diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py 
b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 1251c49462b8..3661dfff8434 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -1038,6 +1038,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1047,7 +1049,13 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ # Same implementation as for images - return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + return self.get_image_features( + pixel_values_videos, + video_grid_thw, + image_cu_seqlens=video_cu_seqlens, + image_rotary_pos_ids=video_rotary_pos_ids, + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1055,6 +1063,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1067,8 +1077,8 @@ def get_image_features( vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -1183,6 +1193,10 @@ def forward( image_grid_thw: torch.LongTensor | None = None, video_grid_thw: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen3VLModelOutputWithPast: r""" @@ -1202,7 +1216,11 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + return_dict=True, ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -1214,7 +1232,11 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + return_dict=True, ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features @@ -1329,6 +1351,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1338,7 +1362,11 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. 
""" return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -1346,6 +1374,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1354,7 +1384,13 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @can_return_tuple def forward( diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index eda65088474b..2b64bb263045 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -690,6 +690,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -702,8 +704,8 @@ def get_image_features( vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -720,6 +722,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -729,7 +733,7 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. 
""" # Same implementation as for images - return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + return self.get_image_features(pixel_values_videos, video_grid_thw, image_cu_seqlens=video_cu_seqlens, image_rotary_pos_ids=video_rotary_pos_ids, **kwargs) @auto_docstring @can_return_tuple @@ -745,6 +749,10 @@ def forward( image_grid_thw: torch.LongTensor | None = None, video_grid_thw: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen3VLModelOutputWithPast: r""" @@ -764,7 +772,8 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, image_cu_seqlens, image_rotary_pos_ids, + return_dict=True, ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -776,7 +785,8 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, video_cu_seqlens, video_rotary_pos_ids, + return_dict=True, ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 008b48ae78b0..69412bbe7c19 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -1168,6 +1168,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1177,7 +1179,13 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. 
""" # Same implementation as for images - return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + return self.get_image_features( + pixel_values_videos, + video_grid_thw, + image_cu_seqlens=video_cu_seqlens, + image_rotary_pos_ids=video_rotary_pos_ids, + **kwargs, + ) @can_return_tuple @auto_docstring @@ -1185,6 +1193,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1197,8 +1207,8 @@ def get_image_features( vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -1313,6 +1323,10 @@ def forward( image_grid_thw: torch.LongTensor | None = None, video_grid_thw: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen3VLMoeModelOutputWithPast: r""" @@ -1332,7 +1346,11 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + return_dict=True, ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -1344,7 +1362,11 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + return_dict=True, ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features @@ -1512,6 +1534,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1521,7 +1545,11 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -1529,6 +1557,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1537,7 +1567,13 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @can_return_tuple def forward( diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index cfc594a84f17..ccc9568c440e 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -527,6 +527,8 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor, video_merge_sizes: torch.LongTensor, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -541,6 +543,8 @@ def get_video_features( pixel_values=pixel_values_videos, image_grid_thw=video_grid_thw, image_merge_sizes=video_merge_sizes, + image_cu_seqlens=video_cu_seqlens, + image_rotary_pos_ids=video_rotary_pos_ids, **kwargs, ) @@ -551,6 +555,8 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor, image_merge_sizes: torch.LongTensor, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -565,8 +571,8 @@ def get_image_features( pixel_values=pixel_values, grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -755,6 +761,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -764,7 +772,11 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -772,6 +784,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -780,7 +794,13 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index a22c5321c048..18971e887788 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -488,6 +488,8 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor, video_merge_sizes: torch.LongTensor, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -502,6 +504,8 @@ def get_video_features( pixel_values=pixel_values_videos, image_grid_thw=video_grid_thw, image_merge_sizes=video_merge_sizes, + image_cu_seqlens=video_cu_seqlens, + image_rotary_pos_ids=video_rotary_pos_ids, **kwargs, ) @@ -512,6 +516,8 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor, image_merge_sizes: torch.LongTensor, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -526,8 +532,8 @@ def get_image_features( pixel_values=pixel_values, grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) From 2a014a4a9df901c20637d9433ce6078c5e37763f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 13:16:44 +0200 Subject: [PATCH 11/56] propagate inputs --- .../modeling_ernie4_5_vl_moe.py | 10 +++- .../modular_ernie4_5_vl_moe.py | 2 +- .../models/glm46v/modeling_glm46v.py | 46 ++++++++++++++++--- .../models/glm4v/modeling_glm4v.py | 14 +++++- .../models/glm4v/modular_glm4v.py | 6 ++- .../models/glm4v_moe/modeling_glm4v_moe.py | 14 +++++- .../models/glm_image/modeling_glm_image.py | 4 ++ .../models/glm_image/modular_glm_image.py | 4 ++ .../models/glm_ocr/modeling_glm_ocr.py | 14 +++++- .../models/glm_ocr/modular_glm_ocr.py | 2 +- .../paddleocr_vl/modeling_paddleocr_vl.py | 14 +++++- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 45 +++++++++++------- .../qwen2_5_omni/modular_qwen2_5_omni.py | 45 +++++++++++------- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 18 +++++++- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 20 ++++++-- .../models/qwen2_vl/modeling_qwen2_vl.py | 40 ++++++++++++++-- .../models/qwen3_5/modeling_qwen3_5.py | 18 +++++++- .../models/qwen3_5/modular_qwen3_5.py | 2 +- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 18 +++++++- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 10 +++- .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 8 ++++ .../models/qwen3_vl/modeling_qwen3_vl.py | 26 ++++++++++- .../models/qwen3_vl/modular_qwen3_vl.py | 28 +++++++++-- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 26 ++++++++++- .../video_llama_3/modeling_video_llama_3.py | 22 ++++++++- .../video_llama_3/modular_video_llama_3.py | 14 +++++- 26 files changed, 401 insertions(+), 69 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py 
b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index 16a7d2359e76..b2f819459fd0 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -915,7 +915,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -1599,6 +1599,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, @@ -1622,6 +1626,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index b11d2426a652..a701d4e006a9 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -711,7 +711,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index 3d65679d6a2a..df3884ee05c1 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -282,6 +282,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -289,6 +291,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. 
+ video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -301,8 +307,8 @@ def get_video_features( vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, - cu_seqlens=kwargs.pop("video_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("video_rotary_pos_ids", None), + cu_seqlens=video_cu_seqlens, + rotary_pos_ids=video_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -318,6 +324,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -325,13 +333,17 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor`, *optional*): + Precomputed cumulative sequence lengths (from the processor). + image_rotary_pos_ids (`torch.Tensor`, *optional*): + Precomputed rotary position IDs (from the processor). """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, - cu_seqlens=kwargs.pop("image_cu_seqlens", None), - rotary_pos_ids=kwargs.pop("image_rotary_pos_ids", None), + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -552,6 +564,8 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -559,9 +573,17 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -569,6 +591,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -576,8 +600,18 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
+ image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 7602781dcf87..d679499a6be4 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -758,7 +758,7 @@ def forward( if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -1110,6 +1110,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -1388,6 +1392,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, @@ -1411,6 +1419,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
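The optional `image_cu_seqlens` / `image_rotary_pos_ids` arguments documented above carry exactly what the encoder's internal fallback would otherwise compute from `image_grid_thw`. A minimal caller-side sketch of that precompute, assuming the helpers are importable from the shared vision utilities module and using illustrative grid values:

    import torch

    # Assumed import location of the shared helpers used throughout this series.
    from transformers.modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens

    grid_thw = torch.tensor([[1, 4, 6]])   # one image: 1 frame, 4 x 6 patches
    spatial_merge_size = 2                 # read from the vision config in practice

    image_cu_seqlens = get_vision_cu_seqlens(grid_thw)                       # tensor([0, 24], dtype=torch.int32)
    image_rotary_pos_ids = get_rotary_pos_ids(grid_thw, spatial_merge_size)  # shape (24, 2)

    # Handing both tensors to the model skips the data-dependent branch inside the
    # encoder forward, e.g.:
    # model.get_image_features(pixel_values, image_grid_thw, image_cu_seqlens, image_rotary_pos_ids)

Because the fallback only runs when these arguments are `None`, existing callers that pass nothing keep working unchanged.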
""" return self.model.get_image_features( pixel_values, diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 4cf37af4b4a0..858be7448ade 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -643,7 +643,7 @@ def forward( if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -798,6 +798,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (): + + video_rotary_pos_ids (): + """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 21a0f33994f2..fb3de2fda9f3 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -822,7 +822,7 @@ def forward( if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -1279,6 +1279,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -1612,6 +1616,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, @@ -1635,6 +1643,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
""" return self.model.get_image_features( pixel_values, diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index fcbc4106d7e9..a1fe288ad59b 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -1179,6 +1179,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index 6ad5a61ea1c1..993a79c17485 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -718,6 +718,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (): + + image_rotary_pos_ids (): + """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index aa75dec92165..f5b91e1cf89d 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -609,7 +609,7 @@ def forward( if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -1027,6 +1027,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -1305,6 +1309,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). 
""" return self.model.get_video_features( pixel_values_videos, @@ -1328,6 +1336,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, diff --git a/src/transformers/models/glm_ocr/modular_glm_ocr.py b/src/transformers/models/glm_ocr/modular_glm_ocr.py index 5fc29c0c131a..a2cf56f94954 100644 --- a/src/transformers/models/glm_ocr/modular_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modular_glm_ocr.py @@ -276,7 +276,7 @@ def forward( if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 9e3e83434958..38451c2074e0 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -1398,6 +1398,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1405,8 +1407,18 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index d0e4c5027ae0..bd13ea077d2b 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -776,20 +776,17 @@ def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: def get_pool_indices(feature_lens: torch.Tensor) -> torch.Tensor: - """Compute indices for stride-2 pooling over post-CNN audio features. - - Selects every other position (even indices) from each sample's post-CNN - features, accounting for two convolution stages and variable-length samples. + """Compute indices for post-encoder stride-2 average pooling. Args: - feature_lens: ``(batch_size,)`` per-sample raw frame counts. 
+ feature_lens: ``(batch_size,)`` mel spectrogram lengths. Returns: - ``(total_pairs,)`` flat indices for stride-2 pooling across concatenated samples. + ``(total_pooled,)`` flat index of first element of each pair. """ after_conv1 = (feature_lens - 1) // 2 + 1 after_conv2 = (after_conv1 - 2) // 2 + 1 - num_pairs = (after_conv2 - 1 + 1) // 2 + num_pairs = after_conv2 // 2 offsets = F.pad(after_conv2[:-1].cumsum(0), (1, 0), value=0) pair_offsets = torch.repeat_interleave(offsets, num_pairs) local_indices = torch.arange(num_pairs.sum(), device=feature_lens.device) @@ -853,6 +850,7 @@ def forward( self, input_features=None, feature_lens=None, + aftercnn_lens=None, padded_feature=None, chunk_lengths=None, valid_indices=None, @@ -863,6 +861,8 @@ def forward( r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length + aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + mel length after cnn padded_feature (`torch.FloatTensor`, *optional*): Precomputed padded audio chunks (from `chunk_and_pad_features`). chunk_lengths (`torch.LongTensor`, *optional*): @@ -870,9 +870,9 @@ def forward( valid_indices (`torch.LongTensor`, *optional*): Precomputed flat indices of valid post-CNN positions (from `get_valid_indices`). pool_indices (`torch.LongTensor`, *optional*): - Precomputed pair indices for post-encoder average pooling (from `get_pool_indices`). + Precomputed pair indices for stride-2 average pooling (from `get_pool_indices`). cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + Precomputed cumulative sequence lengths (from `get_audio_cu_seqlens`). """ if padded_feature is None: padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) @@ -883,7 +883,11 @@ def forward( if pool_indices is None: pool_indices = get_pool_indices(feature_lens) + if aftercnn_lens is None and feature_lens is not None: + aftercnn_lens, _ = self._get_feat_extract_output_lengths(feature_lens) + # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) + padded_feature = padded_feature.to(self.conv1.weight.dtype) padded_mask = ( (torch.arange(padded_feature.shape[2], device=padded_feature.device) < chunk_lengths.unsqueeze(1)) .unsqueeze(1) @@ -926,11 +930,10 @@ def forward( ) hidden_states = layer_outputs[0] - # Post-process: average consecutive pairs per audio, then project - pooled = (hidden_states[pool_indices] + hidden_states[pool_indices + 1]) / 2 - pooled = self.ln_post(pooled) - token_audio = self.proj(pooled) - return BaseModelOutputWithPooling(last_hidden_state=token_audio) + # Post-process: stride-2 average pooling using precomputed indices, then project + hidden_states = (hidden_states[pool_indices] + hidden_states[pool_indices + 1]) / 2 + hidden_states = self.proj(self.ln_post(hidden_states)) + return BaseModelOutputWithPooling(last_hidden_state=hidden_states) # Ignore copy def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): @@ -1240,7 +1243,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1719,6 +1722,10 @@ def get_video_features( The tensors corresponding to the input videos. 
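The `num_pairs = after_conv2 // 2` line is arithmetically identical to the previous `(after_conv2 - 1 + 1) // 2`, just written without the redundant `- 1 + 1`; any trailing odd post-CNN frame is simply not pooled. A worked example of the index arithmetic for two clips with mel lengths 30 and 11 (the final combination of `pair_offsets` and `local_indices` is reconstructed here for illustration, since the tail of the helper falls outside this hunk):

    import torch
    import torch.nn.functional as F

    feature_lens = torch.tensor([30, 11])

    after_conv1 = (feature_lens - 1) // 2 + 1     # tensor([15,  6])  length after conv1 (stride 2)
    after_conv2 = (after_conv1 - 2) // 2 + 1      # tensor([ 7,  3])  length after conv2 (stride 2)
    num_pairs = after_conv2 // 2                  # tensor([ 3,  1])  averaged pairs per clip

    offsets = F.pad(after_conv2[:-1].cumsum(0), (1, 0), value=0)   # tensor([0, 7])  start of each clip
    pair_offsets = torch.repeat_interleave(offsets, num_pairs)     # tensor([0, 0, 0, 7])
    local_indices = torch.arange(num_pairs.sum())                  # tensor([0, 1, 2, 3])

    # Reconstructed combination (illustrative): even index of each pair, per clip.
    pair_starts = F.pad(num_pairs[:-1].cumsum(0), (1, 0), value=0)                 # tensor([0, 3])
    within_clip = local_indices - torch.repeat_interleave(pair_starts, num_pairs)  # tensor([0, 1, 2, 0])
    pool_indices = pair_offsets + 2 * within_clip                                  # tensor([0, 2, 4, 7])

    # The encoder then averages hidden_states[pool_indices] with hidden_states[pool_indices + 1],
    # i.e. frames (0,1), (2,3), (4,5) of clip 1 and (7,8) of clip 2; trailing odd frames are dropped.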
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) return self.visual( @@ -1744,6 +1751,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) return self.visual( @@ -1781,7 +1792,9 @@ def get_audio_features( audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) ) feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) + audio_outputs = self.audio_tower( + input_features, feature_lens=feature_lens, aftercnn_lens=audio_feat_lengths, return_dict=True, **kwargs + ) if audio_outputs.last_hidden_state.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 2d52008dfff4..0764c1ee9a7f 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -126,20 +126,17 @@ def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: def get_pool_indices(feature_lens: torch.Tensor) -> torch.Tensor: - """Compute indices for stride-2 pooling over post-CNN audio features. - - Selects every other position (even indices) from each sample's post-CNN - features, accounting for two convolution stages and variable-length samples. + """Compute indices for post-encoder stride-2 average pooling. Args: - feature_lens: ``(batch_size,)`` per-sample raw frame counts. + feature_lens: ``(batch_size,)`` mel spectrogram lengths. Returns: - ``(total_pairs,)`` flat indices for stride-2 pooling across concatenated samples. + ``(total_pooled,)`` flat index of first element of each pair. 
""" after_conv1 = (feature_lens - 1) // 2 + 1 after_conv2 = (after_conv1 - 2) // 2 + 1 - num_pairs = (after_conv2 - 1 + 1) // 2 + num_pairs = after_conv2 // 2 offsets = F.pad(after_conv2[:-1].cumsum(0), (1, 0), value=0) pair_offsets = torch.repeat_interleave(offsets, num_pairs) local_indices = torch.arange(num_pairs.sum(), device=feature_lens.device) @@ -1335,6 +1332,7 @@ def forward( self, input_features=None, feature_lens=None, + aftercnn_lens=None, padded_feature=None, chunk_lengths=None, valid_indices=None, @@ -1345,6 +1343,8 @@ def forward( r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length + aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + mel length after cnn padded_feature (`torch.FloatTensor`, *optional*): Precomputed padded audio chunks (from `chunk_and_pad_features`). chunk_lengths (`torch.LongTensor`, *optional*): @@ -1352,9 +1352,9 @@ def forward( valid_indices (`torch.LongTensor`, *optional*): Precomputed flat indices of valid post-CNN positions (from `get_valid_indices`). pool_indices (`torch.LongTensor`, *optional*): - Precomputed pair indices for post-encoder average pooling (from `get_pool_indices`). + Precomputed pair indices for stride-2 average pooling (from `get_pool_indices`). cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + Precomputed cumulative sequence lengths (from `get_audio_cu_seqlens`). """ if padded_feature is None: padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) @@ -1365,7 +1365,11 @@ def forward( if pool_indices is None: pool_indices = get_pool_indices(feature_lens) + if aftercnn_lens is None and feature_lens is not None: + aftercnn_lens, _ = self._get_feat_extract_output_lengths(feature_lens) + # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) + padded_feature = padded_feature.to(self.conv1.weight.dtype) padded_mask = ( (torch.arange(padded_feature.shape[2], device=padded_feature.device) < chunk_lengths.unsqueeze(1)) .unsqueeze(1) @@ -1408,11 +1412,10 @@ def forward( ) hidden_states = layer_outputs[0] - # Post-process: average consecutive pairs per audio, then project - pooled = (hidden_states[pool_indices] + hidden_states[pool_indices + 1]) / 2 - pooled = self.ln_post(pooled) - token_audio = self.proj(pooled) - return BaseModelOutputWithPooling(last_hidden_state=token_audio) + # Post-process: stride-2 average pooling using precomputed indices, then project + hidden_states = (hidden_states[pool_indices] + hidden_states[pool_indices + 1]) / 2 + hidden_states = self.proj(self.ln_post(hidden_states)) + return BaseModelOutputWithPooling(last_hidden_state=hidden_states) # Ignore copy def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): @@ -1595,7 +1598,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1729,6 +1732,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
+ video_cu_seqlens (): + + video_rotary_pos_ids (): + """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) return self.visual( @@ -1754,6 +1761,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (): + + image_rotary_pos_ids (): + """ pixel_values = pixel_values.type(self.visual.dtype) return self.visual( @@ -1791,7 +1802,9 @@ def get_audio_features( audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) ) feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) + audio_outputs = self.audio_tower( + input_features, feature_lens=feature_lens, aftercnn_lens=audio_feat_lengths, return_dict=True, **kwargs + ) if audio_outputs.last_hidden_state.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index fc623a11ba16..ee0e6aa239ef 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -412,7 +412,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1264,6 +1264,14 @@ def forward( The rope index difference between sequence length and multimodal rope. second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ if inputs_embeds is None: @@ -1387,6 +1395,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, @@ -1410,6 +1422,10 @@ def get_image_features( The tensors corresponding to the input images. 
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 754889619ddc..8f43e21b6c92 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -253,7 +253,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -508,6 +508,14 @@ def forward( The rope index difference between sequence length and multimodal rope. second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + image_cu_seqlens (): + + image_rotary_pos_ids (): + + video_cu_seqlens (): + + video_rotary_pos_ids (): + """ if inputs_embeds is None: @@ -515,7 +523,10 @@ def forward( if pixel_values is not None: image_embeds = self.get_image_features( - pixel_values, image_grid_thw, image_cu_seqlens, image_rotary_pos_ids, + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( @@ -525,7 +536,10 @@ def forward( if pixel_values_videos is not None: video_embeds = self.get_video_features( - pixel_values_videos, video_grid_thw, video_cu_seqlens, video_rotary_pos_ids, + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index b82a03eb07c7..e88ad7497e0b 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -744,7 +744,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -1266,7 +1266,10 @@ def forward( if pixel_values is not None: image_embeds = self.get_image_features( - pixel_values, image_grid_thw, image_cu_seqlens, image_rotary_pos_ids, + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( @@ -1276,7 +1279,10 @@ def forward( if pixel_values_videos is not None: video_embeds = self.get_video_features( - pixel_values_videos, 
video_grid_thw, video_cu_seqlens, video_rotary_pos_ids, + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( @@ -1343,9 +1349,17 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( - pixel_values_videos, video_grid_thw, video_cu_seqlens, video_rotary_pos_ids, **kwargs, + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, + **kwargs, ) @auto_docstring @@ -1362,9 +1376,17 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( - pixel_values, image_grid_thw, image_cu_seqlens, image_rotary_pos_ids, **kwargs, + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, ) @can_return_tuple @@ -1402,6 +1424,14 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). Example: diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index e324ab25186f..1c985ecb588a 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1069,7 +1069,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1405,6 +1405,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
+ video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ # Same implementation as for images return self.get_image_features( @@ -1430,6 +1434,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithPooling = self.visual( @@ -1760,6 +1768,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, @@ -1783,6 +1795,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
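The same pair of tensors can also be threaded through the top-level forwards updated above, which now list them in their docstrings. A short sketch reusing the helpers from the earlier example; `inputs` is assumed to be a processor output and the attribute path to the merge size is illustrative:

    # `inputs` is assumed to hold `input_ids`, `pixel_values` and `image_grid_thw`,
    # and `model` any of the generation models updated in this patch.
    grid_thw = inputs["image_grid_thw"]
    merge_size = model.model.visual.spatial_merge_size   # attribute path is an assumption

    outputs = model(
        **inputs,
        image_cu_seqlens=get_vision_cu_seqlens(grid_thw),
        image_rotary_pos_ids=get_rotary_pos_ids(grid_thw, merge_size),
    )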
""" return self.model.get_image_features( pixel_values, diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 99d17e3cc285..3bded3a51c66 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -462,7 +462,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index e6d0244c880b..4fe9c4971aea 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1162,7 +1162,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1530,6 +1530,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ # Same implementation as for images return self.get_image_features( @@ -1555,6 +1559,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithPooling = self.visual( @@ -1959,6 +1967,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, @@ -1982,6 +1994,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
+ image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index cd0988daa0c0..55909fdd4481 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1159,7 +1159,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1878,6 +1878,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) return self.visual( @@ -1903,6 +1907,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) return self.visual( diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index fa2cbdb4e77d..56e7e56e47f8 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1211,6 +1211,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (): + + video_rotary_pos_ids (): + """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) return self.visual( @@ -1236,6 +1240,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
+ image_cu_seqlens (): + + image_rotary_pos_ids (): + """ pixel_values = pixel_values.type(self.visual.dtype) return self.visual( diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 3661dfff8434..206634dafbbe 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -700,7 +700,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1047,6 +1047,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ # Same implementation as for images return self.get_image_features( @@ -1072,6 +1076,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( @@ -1204,6 +1212,14 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1360,6 +1376,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). 
+ video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, @@ -1383,6 +1403,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 2b64bb263045..a7a95f1b4d4c 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -488,7 +488,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -699,6 +699,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (): + + image_rotary_pos_ids (): + """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( @@ -731,9 +735,19 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
+ video_cu_seqlens (): + + video_rotary_pos_ids (): + """ # Same implementation as for images - return self.get_image_features(pixel_values_videos, video_grid_thw, image_cu_seqlens=video_cu_seqlens, image_rotary_pos_ids=video_rotary_pos_ids, **kwargs) + return self.get_image_features( + pixel_values_videos, + video_grid_thw, + image_cu_seqlens=video_cu_seqlens, + image_rotary_pos_ids=video_rotary_pos_ids, + **kwargs, + ) @auto_docstring @can_return_tuple @@ -772,7 +786,10 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, image_cu_seqlens, image_rotary_pos_ids, + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, return_dict=True, ) image_embeds = image_outputs.pooler_output @@ -785,7 +802,10 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, video_cu_seqlens, video_rotary_pos_ids, + pixel_values_videos, + video_grid_thw, + video_cu_seqlens, + video_rotary_pos_ids, return_dict=True, ) video_embeds = video_outputs.pooler_output diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 69412bbe7c19..ce64184920d6 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -685,7 +685,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1177,6 +1177,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ # Same implementation as for images return self.get_image_features( @@ -1202,6 +1206,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( @@ -1334,6 +1342,14 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
+ image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1543,6 +1559,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, @@ -1566,6 +1586,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index ccc9568c440e..de5866a080d1 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -428,17 +428,19 @@ def forward( rotary_pos_ids (`torch.Tensor`, *optional*): Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). """ + + hidden_states = self.embeddings(pixel_values.type(self.dtype)) + if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, merge_sizes) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - hidden_states = self.embeddings(pixel_values.type(self.dtype)) encoder_outputs: BaseModelOutput = self.encoder( hidden_states, cu_seqlens=cu_seqlens, @@ -538,6 +540,10 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`): The spatial downsampling ratio of each video feature. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. 
+ video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ return self.get_image_features( pixel_values=pixel_values_videos, @@ -566,6 +572,10 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_merge_sizes (`torch.Tensor` of shape `(num_images,)`): The spatial downsampling ratio of each image feature. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ vision_outputs = self.vision_model( pixel_values=pixel_values, @@ -770,6 +780,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, @@ -793,6 +807,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 18971e887788..46c2fcc1a8c8 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -390,17 +390,19 @@ def forward( rotary_pos_ids (`torch.Tensor`, *optional*): Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). """ + + hidden_states = self.embeddings(pixel_values.type(self.dtype)) + if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, merge_sizes) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - hidden_states = self.embeddings(pixel_values.type(self.dtype)) encoder_outputs: BaseModelOutput = self.encoder( hidden_states, cu_seqlens=cu_seqlens, @@ -499,6 +501,10 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`): The spatial downsampling ratio of each video feature. 
+ video_cu_seqlens (): + + video_rotary_pos_ids (): + """ return self.get_image_features( pixel_values=pixel_values_videos, @@ -527,6 +533,10 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_merge_sizes (`torch.Tensor` of shape `(num_images,)`): The spatial downsampling ratio of each image feature. + image_cu_seqlens (): + + image_rotary_pos_ids (): + """ vision_outputs = self.vision_model( pixel_values=pixel_values, From 957372a110ca3fcadbc78223dd3f1484955f6459 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 13:18:27 +0200 Subject: [PATCH 12/56] fix docs --- src/transformers/models/glm4v/modular_glm4v.py | 8 ++++---- .../models/glm_image/modular_glm_image.py | 8 ++++---- .../models/qwen2_5_omni/modular_qwen2_5_omni.py | 16 ++++++++-------- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 16 ++++++++-------- .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 16 ++++++++-------- .../models/qwen3_vl/modular_qwen3_vl.py | 16 ++++++++-------- .../video_llama_3/modular_video_llama_3.py | 16 ++++++++-------- 7 files changed, 48 insertions(+), 48 deletions(-) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 858be7448ade..cc27197856a5 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -798,10 +798,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (): - - video_rotary_pos_ids (): - + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index 993a79c17485..58696e869ed8 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -718,10 +718,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (): - - image_rotary_pos_ids (): - + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. 
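# Illustrative sketch of the precompute pattern these docstrings describe. It assumes the
# shared helpers land in `transformers.modeling_vision_utils` as in this series; `model`,
# `pixel_values` and the 4x4 grid below are hypothetical.
import torch

from transformers.modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens

image_grid_thw = torch.tensor([[1, 4, 4]])  # one image: 1 frame, 4x4 patch grid
image_cu_seqlens = get_vision_cu_seqlens(image_grid_thw)  # -> [0, 16], dtype int32
image_rotary_pos_ids = get_rotary_pos_ids(image_grid_thw, spatial_merge_size=2)  # shape (16, 2)

# With the data-dependent tensors computed eagerly, the call below no longer has to run
# repeat_interleave/.tolist() inside the vision tower (e.g. ahead of torch.export):
# features = model.get_image_features(
#     pixel_values,
#     image_grid_thw,
#     image_cu_seqlens=image_cu_seqlens,
#     image_rotary_pos_ids=image_rotary_pos_ids,
# )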
""" pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 0764c1ee9a7f..4c26d2a1883d 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1732,10 +1732,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (): - - video_rotary_pos_ids (): - + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) return self.visual( @@ -1761,10 +1761,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (): - - image_rotary_pos_ids (): - + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) return self.visual( diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 8f43e21b6c92..fb47f4250e33 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -508,14 +508,14 @@ def forward( The rope index difference between sequence length and multimodal rope. second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. - image_cu_seqlens (): - - image_rotary_pos_ids (): - - video_cu_seqlens (): - - video_rotary_pos_ids (): - + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. 
""" if inputs_embeds is None: diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 56e7e56e47f8..84ddb7d6e7be 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1211,10 +1211,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (): - - video_rotary_pos_ids (): - + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) return self.visual( @@ -1240,10 +1240,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (): - - image_rotary_pos_ids (): - + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) return self.visual( diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index a7a95f1b4d4c..7d29cc48c57d 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -699,10 +699,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (): - - image_rotary_pos_ids (): - + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( @@ -735,10 +735,10 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (): - - video_rotary_pos_ids (): - + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. 
""" # Same implementation as for images return self.get_image_features( diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 46c2fcc1a8c8..8c2fc062c592 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -501,10 +501,10 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`): The spatial downsampling ratio of each video feature. - video_cu_seqlens (): - - video_rotary_pos_ids (): - + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ return self.get_image_features( pixel_values=pixel_values_videos, @@ -533,10 +533,10 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_merge_sizes (`torch.Tensor` of shape `(num_images,)`): The spatial downsampling ratio of each image feature. - image_cu_seqlens (): - - image_rotary_pos_ids (): - + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ vision_outputs = self.vision_model( pixel_values=pixel_values, From 4194ff1fdd585178559c52285b248739c3116cab Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 13:27:19 +0200 Subject: [PATCH 13/56] fix docs --- .../ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py | 16 ++++++++++++++++ .../ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py | 16 ++++++++++++++++ .../models/glm46v/modeling_glm46v.py | 12 ++++++++++++ src/transformers/models/glm4v/modeling_glm4v.py | 12 ++++++++++++ src/transformers/models/glm4v/modular_glm4v.py | 12 ++++++++++++ .../models/glm4v_moe/modeling_glm4v_moe.py | 12 ++++++++++++ .../models/glm_ocr/modeling_glm_ocr.py | 12 ++++++++++++ .../models/qwen3_5/modeling_qwen3_5.py | 8 ++++++++ .../models/qwen3_5/modular_qwen3_5.py | 8 ++++++++ .../models/qwen3_5_moe/modeling_qwen3_5_moe.py | 8 ++++++++ .../models/qwen3_vl/modular_qwen3_vl.py | 8 ++++++++ 11 files changed, 124 insertions(+) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index b2f819459fd0..2a151a502d26 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -1429,6 +1429,14 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. 
+ image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -1676,6 +1684,14 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.text_config.output_router_logits diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index a701d4e006a9..a0569756ccbf 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -1046,6 +1046,14 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -1185,6 +1193,14 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. 
+ image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.text_config.output_router_logits diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index df3884ee05c1..46677bdd79da 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -295,6 +295,10 @@ def get_video_features( Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -467,6 +471,14 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index d679499a6be4..7ac19671f4a7 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1114,6 +1114,10 @@ def get_video_features( Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. 
+ image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -1286,6 +1290,14 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index cc27197856a5..e2510038c55d 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -802,6 +802,10 @@ def get_video_features( Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -1008,6 +1012,14 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. 
""" if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index fb3de2fda9f3..586cebd3f3a3 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -1283,6 +1283,10 @@ def get_video_features( Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -1455,6 +1459,14 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index f5b91e1cf89d..fd9fb53434ec 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -1031,6 +1031,10 @@ def get_video_features( Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -1203,6 +1207,14 @@ def forward( The temporal, height and width of feature shape of each video in LLM. 
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 1c985ecb588a..c9ca20fc14c0 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1566,6 +1566,14 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 3bded3a51c66..a910724b3fe6 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -621,6 +621,14 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. 
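# Worked example of the cumulative sequence lengths documented above (illustrative; assumes
# `get_vision_cu_seqlens` from `transformers.modeling_vision_utils` in this series). For two
# videos packed into one sequence, the boundaries are per frame, so packed variable-length
# attention keeps every frame's patches in their own segment.
import torch

from transformers.modeling_vision_utils import get_vision_cu_seqlens

video_grid_thw = torch.tensor([[2, 2, 2], [1, 4, 2]])  # (t, h, w) per video
# Per-frame patch counts are h*w repeated t times: [4, 4, 8] -> boundaries [0, 4, 8, 16],
# i.e. patches 0..3, 4..7 and 8..15 each form an independent attention segment.
print(get_vision_cu_seqlens(video_grid_thw))  # -> [0, 4, 8, 16], dtype int32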
""" if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 4fe9c4971aea..ac2198d5878c 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1691,6 +1691,14 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 7d29cc48c57d..8e2949d699fd 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -774,6 +774,14 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. + video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. 
""" if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") From 836424bf9196892a43c47cc350f44d1821ca2f4f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 13:42:57 +0200 Subject: [PATCH 14/56] auto docs --- .../modeling_ernie4_5_vl_moe.py | 16 ----------- .../modular_ernie4_5_vl_moe.py | 16 ----------- .../models/glm46v/modeling_glm46v.py | 16 ----------- .../models/glm4v/modeling_glm4v.py | 16 ----------- .../models/glm4v/modular_glm4v.py | 16 ----------- .../models/glm4v_moe/modeling_glm4v_moe.py | 16 ----------- .../models/glm_image/modeling_glm_image.py | 4 --- .../models/glm_image/modular_glm_image.py | 4 --- .../models/glm_ocr/modeling_glm_ocr.py | 16 ----------- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 8 ------ .../qwen2_5_omni/modular_qwen2_5_omni.py | 4 --- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 8 ------ .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 8 ------ .../models/qwen3_5/modeling_qwen3_5.py | 16 ----------- .../models/qwen3_5/modular_qwen3_5.py | 8 ------ .../qwen3_5_moe/modeling_qwen3_5_moe.py | 16 ----------- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 8 ------ .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 4 --- .../models/qwen3_vl/modeling_qwen3_vl.py | 16 ----------- .../models/qwen3_vl/modular_qwen3_vl.py | 12 -------- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 16 ----------- .../video_llama_3/modeling_video_llama_3.py | 8 ------ .../video_llama_3/modular_video_llama_3.py | 4 --- src/transformers/utils/auto_docstring.py | 28 +++++++++++++++++++ 24 files changed, 28 insertions(+), 256 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index 2a151a502d26..b2f819459fd0 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -1429,14 +1429,6 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -1684,14 +1676,6 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. 
- image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.text_config.output_router_logits diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index a0569756ccbf..a701d4e006a9 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -1046,14 +1046,6 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) @@ -1193,14 +1185,6 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.text_config.output_router_logits diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index 46677bdd79da..a05ffe5d1890 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -291,14 +291,6 @@ def get_video_features( The tensors corresponding to the input videos. 
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -471,14 +463,6 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 7ac19671f4a7..c488db72ca67 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1110,14 +1110,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -1290,14 +1282,6 @@ def forward( The temporal, height and width of feature shape of each video in LLM. 
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index e2510038c55d..56b5ae0b6a39 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -798,14 +798,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -1012,14 +1004,6 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. 
""" if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 586cebd3f3a3..30e692a5b318 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -1279,14 +1279,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -1459,14 +1451,6 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index a1fe288ad59b..fcbc4106d7e9 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -1179,10 +1179,6 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. 
""" pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index 58696e869ed8..6ad5a61ea1c1 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -718,10 +718,6 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index fd9fb53434ec..a170fabee6e7 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -1027,14 +1027,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames @@ -1207,14 +1199,6 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. 
""" if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index bd13ea077d2b..e12e58f2c828 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1722,10 +1722,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) return self.visual( @@ -1751,10 +1747,6 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) return self.visual( diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 4c26d2a1883d..a9febfd3a35b 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1761,10 +1761,6 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) return self.visual( diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index ee0e6aa239ef..eb15c13b8363 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1264,14 +1264,6 @@ def forward( The rope index difference between sequence length and multimodal rope. second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. 
- image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if inputs_embeds is None: diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index fb47f4250e33..d9b27aafbb9c 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -508,14 +508,6 @@ def forward( The rope index difference between sequence length and multimodal rope. second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if inputs_embeds is None: diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index c9ca20fc14c0..a806e531371a 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1405,10 +1405,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ # Same implementation as for images return self.get_image_features( @@ -1434,10 +1430,6 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithPooling = self.visual( @@ -1566,14 +1558,6 @@ def forward( The temporal, height and width of feature shape of each image in LLM. 
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index a910724b3fe6..3bded3a51c66 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -621,14 +621,6 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index ac2198d5878c..dd58f1a99acd 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1530,10 +1530,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ # Same implementation as for images return self.get_image_features( @@ -1559,10 +1555,6 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
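The hunks above and below keep dropping the same four hand-written argument descriptions; the arguments themselves stay on the forwards and are now documented centrally. A minimal sketch, assuming a Qwen2.5-VL-style checkpoint whose forward accepts `image_cu_seqlens` and `image_rotary_pos_ids`, of precomputing them with the shared utilities so the data-dependent ops run eagerly before tracing:

    import torch
    from transformers.modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens

    # grid produced by the processor: one (t, h, w) row per image
    image_grid_thw = torch.tensor([[1, 28, 28]])  # hypothetical 28x28-patch image

    image_cu_seqlens = get_vision_cu_seqlens(image_grid_thw)      # (num_image_patches + 1,)
    image_rotary_pos_ids = get_rotary_pos_ids(image_grid_thw, 2)  # (num_image_tokens, 2), merge size 2

    # the precomputed tensors are then forwarded untouched, e.g.:
    # outputs = model(**inputs, image_cu_seqlens=image_cu_seqlens,
    #                 image_rotary_pos_ids=image_rotary_pos_ids)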
- image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithPooling = self.visual( @@ -1691,14 +1683,6 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 55909fdd4481..df9b0228d0e9 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1878,10 +1878,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) return self.visual( @@ -1907,10 +1903,6 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. 
""" pixel_values = pixel_values.type(self.visual.dtype) return self.visual( diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 84ddb7d6e7be..3f0a0228ae92 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1240,10 +1240,6 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) return self.visual( diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 206634dafbbe..6c8a4ba9fc61 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -1047,10 +1047,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ # Same implementation as for images return self.get_image_features( @@ -1076,10 +1072,6 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( @@ -1212,14 +1204,6 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. 
- video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 8e2949d699fd..0bcab937bf0a 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -699,10 +699,6 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( @@ -774,14 +770,6 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index ce64184920d6..68212c551468 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -1177,10 +1177,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ # Same implementation as for images return self.get_image_features( @@ -1206,10 +1202,6 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
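Several of the forward hunks in these files close on the same input guard; the XOR form is terse, so a tiny standalone illustration of when it raises:

    def check(input_ids, inputs_embeds):
        # raises unless exactly one of the two is provided
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    check(input_ids=[1, 2], inputs_embeds=None)   # ok: only input_ids
    check(input_ids=None, inputs_embeds=[[0.1]])  # ok: only inputs_embeds
    # check(None, None) and check([1], [[0.1]]) both raise ValueError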
- image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( @@ -1342,14 +1334,6 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index de5866a080d1..17eb01dfc528 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -540,10 +540,6 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`): The spatial downsampling ratio of each video feature. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ return self.get_image_features( pixel_values=pixel_values_videos, @@ -572,10 +568,6 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_merge_sizes (`torch.Tensor` of shape `(num_images,)`): The spatial downsampling ratio of each image feature. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. 
""" vision_outputs = self.vision_model( pixel_values=pixel_values, diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 8c2fc062c592..a58cfa465895 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -533,10 +533,6 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_merge_sizes (`torch.Tensor` of shape `(num_images,)`): The spatial downsampling ratio of each image feature. - image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - image_rotary_pos_ids (`torch.Tensor` of shape `(num_image_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ vision_outputs = self.vision_model( pixel_values=pixel_values, diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index bd04f3fb901e..fec3ea6728fc 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -2121,6 +2121,34 @@ class ModelArgs: "shape": "of shape `(batch_size, num_frames, num_channels, frame_size, frame_size)`", } + image_cu_seqlens = { + "description": """ + Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. + """, + "shape": "of shape `(num_image_patches + 1,)`", + } + + video_cu_seqlens = { + "description": """ + Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. + """, + "shape": "of shape `(num_video_patches + 1,)`", + } + + image_rotary_pos_ids = { + "description": """ + Precomputed (row, col) position IDs for image rotary embeddings. + """, + "shape": "of shape `(num_image_tokens, 2)`", + } + + video_rotary_pos_ids = { + "description": """ + Precomputed (row, col) position IDs for video rotary embeddings. + """, + "shape": "of shape `(num_video_tokens, 2)`", + } + vision_feature_layer = { "description": """ The index of the layer to select the vision feature. If multiple indices are provided, From 11f73fd4050aeafb0e76761b771e5e53fa344aa5 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 13:45:30 +0200 Subject: [PATCH 15/56] more docs fixing --- src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py | 4 ---- .../models/qwen3_omni_moe/modular_qwen3_omni_moe.py | 4 ---- src/transformers/models/qwen3_vl/modular_qwen3_vl.py | 4 ---- .../models/video_llama_3/modular_video_llama_3.py | 4 ---- 4 files changed, 16 deletions(-) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index a9febfd3a35b..84a0c8d872a9 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1732,10 +1732,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. 
- video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) return self.visual( diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 3f0a0228ae92..fa2cbdb4e77d 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1211,10 +1211,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) return self.visual( diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 0bcab937bf0a..c0d58de15485 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -731,10 +731,6 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ # Same implementation as for images return self.get_image_features( diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index a58cfa465895..deace870771a 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -501,10 +501,6 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`): The spatial downsampling ratio of each video feature. - video_cu_seqlens (`torch.Tensor` of shape `(num_video_patches + 1,)`, *optional*): - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - video_rotary_pos_ids (`torch.Tensor` of shape `(num_video_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. 
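The lines removed here and in the preceding hunks are not lost: they are supplied centrally by the four entries added to `src/transformers/utils/auto_docstring.py` above. A simplified illustration of the idea — not the real `auto_docstring` renderer, just a hypothetical formatter showing how one entry expands back into the text that used to be written per model:

    image_cu_seqlens = {
        "description": "Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention.",
        "shape": "of shape `(num_image_patches + 1,)`",
    }

    def render_arg(name, entry, dtype="torch.Tensor", optional=True):
        # hypothetical helper; the library's own rendering is more involved
        suffix = ", *optional*" if optional else ""
        return f"{name} (`{dtype}` {entry['shape']}{suffix}):\n    {entry['description']}"

    print(render_arg("image_cu_seqlens", image_cu_seqlens))
    # image_cu_seqlens (`torch.Tensor` of shape `(num_image_patches + 1,)`, *optional*):
    #     Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention.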
""" return self.get_image_features( pixel_values=pixel_values_videos, From 71f90eca838f7a3eaa4b8e7c97bd431022fe5e69 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 13:53:08 +0200 Subject: [PATCH 16/56] fix omni --- .../models/qwen2_5_omni/modeling_qwen2_5_omni.py | 13 ++++++------- .../models/qwen2_5_omni/modular_qwen2_5_omni.py | 13 ++++++------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index e12e58f2c828..59f77f7411b1 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -782,15 +782,14 @@ def get_pool_indices(feature_lens: torch.Tensor) -> torch.Tensor: feature_lens: ``(batch_size,)`` mel spectrogram lengths. Returns: - ``(total_pooled,)`` flat index of first element of each pair. + ``(total_pooled,)`` flat index of first element of each stride-2 pair. """ after_conv1 = (feature_lens - 1) // 2 + 1 - after_conv2 = (after_conv1 - 2) // 2 + 1 - num_pairs = after_conv2 // 2 - offsets = F.pad(after_conv2[:-1].cumsum(0), (1, 0), value=0) - pair_offsets = torch.repeat_interleave(offsets, num_pairs) - local_indices = torch.arange(num_pairs.sum(), device=feature_lens.device) - local_indices -= torch.repeat_interleave(F.pad(num_pairs[:-1].cumsum(0), (1, 0), value=0), num_pairs) + num_pooled = (after_conv1 - 2) // 2 + 1 + offsets = F.pad(after_conv1[:-1].cumsum(0), (1, 0), value=0) + pair_offsets = torch.repeat_interleave(offsets, num_pooled) + local_indices = torch.arange(num_pooled.sum(), device=feature_lens.device) + local_indices -= torch.repeat_interleave(F.pad(num_pooled[:-1].cumsum(0), (1, 0), value=0), num_pooled) return pair_offsets + local_indices * 2 diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 84a0c8d872a9..062546e263bf 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -132,15 +132,14 @@ def get_pool_indices(feature_lens: torch.Tensor) -> torch.Tensor: feature_lens: ``(batch_size,)`` mel spectrogram lengths. Returns: - ``(total_pooled,)`` flat index of first element of each pair. + ``(total_pooled,)`` flat index of first element of each stride-2 pair. 
""" after_conv1 = (feature_lens - 1) // 2 + 1 - after_conv2 = (after_conv1 - 2) // 2 + 1 - num_pairs = after_conv2 // 2 - offsets = F.pad(after_conv2[:-1].cumsum(0), (1, 0), value=0) - pair_offsets = torch.repeat_interleave(offsets, num_pairs) - local_indices = torch.arange(num_pairs.sum(), device=feature_lens.device) - local_indices -= torch.repeat_interleave(F.pad(num_pairs[:-1].cumsum(0), (1, 0), value=0), num_pairs) + num_pooled = (after_conv1 - 2) // 2 + 1 + offsets = F.pad(after_conv1[:-1].cumsum(0), (1, 0), value=0) + pair_offsets = torch.repeat_interleave(offsets, num_pooled) + local_indices = torch.arange(num_pooled.sum(), device=feature_lens.device) + local_indices -= torch.repeat_interleave(F.pad(num_pooled[:-1].cumsum(0), (1, 0), value=0), num_pooled) return pair_offsets + local_indices * 2 From a89d4369cd882e7d0035206e0ba967c42027ec52 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 14:58:17 +0200 Subject: [PATCH 17/56] fix paddle --- .../paddleocr_vl/modeling_paddleocr_vl.py | 81 +++++++++---------- .../paddleocr_vl/modular_paddleocr_vl.py | 81 +++++++++---------- 2 files changed, 76 insertions(+), 86 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 38451c2074e0..97df06918065 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -42,6 +42,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check, torch_int from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -586,7 +587,7 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: def forward( self, pixel_values: torch.FloatTensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + image_grid_thw: torch.LongTensor | None = None, ) -> torch.Tensor: """ Args: @@ -819,9 +820,10 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, inputs_embeds: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -829,36 +831,32 @@ def forward( Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): The cumulative sequence lengths of each image or video feature. 
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. + rotary_pos_ids (`torch.Tensor` of shape `(sequence_length,)`, *optional*): + Precomputed rotary position ids. If not provided, will be computed based on `grid_thw`. """ - device = inputs_embeds.device hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, ) - split_hids = [] - split_wids = [] - for t, h, w in image_grid_thw: - image_pids = torch.arange(t * h * w, device=device) % (h * w) - sample_hids = image_pids // w - sample_wids = image_pids % w - split_hids.append(sample_hids) - split_wids.append(sample_wids) - width_position_ids = torch.concat(split_wids, dim=0) - height_position_ids = torch.concat(split_hids, dim=0) - - pids = torch.stack([height_position_ids, width_position_ids], dim=-1) + + if rotary_pos_ids is None: + pids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) + rotary_embeddings = self.rotary_pos_emb(pids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, @@ -897,29 +895,30 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): The tensors corresponding to the input images. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): + Precomputed cumulative sequence lengths. 
""" - hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw) - + hidden_states = self.embeddings(pixel_values, grid_thw=grid_thw) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, + grid_thw=grid_thw, cu_seqlens=cu_seqlens, attention_mask=attention_mask, - image_grid_thw=image_grid_thw, + rotary_pos_ids=rotary_pos_ids, **kwargs, ) @@ -948,23 +947,25 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): The cumulative sequence lengths of each image or video feature. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. """ return self.vision_model( pixel_values=pixel_values, - cu_seqlens=cu_seqlens, - image_grid_thw=image_grid_thw, + image_grid_thw=grid_thw, + image_cu_seqlens=cu_seqlens, + image_rotary_pos_ids=rotary_pos_ids, **kwargs, ) @@ -1211,6 +1212,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1220,19 +1223,11 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) - cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) vision_outputs = self.visual( pixel_values=pixel_values, image_grid_thw=image_grid_thw, - cu_seqlens=cu_seqlens, + image_cu_seqlens=image_cu_seqlens, + image_rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index a3e1ec0dc4d4..cfd96f97b2f0 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -37,6 +37,7 @@ from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...models.qwen2_vl.image_processing_pil_qwen2_vl import Qwen2VLImageProcessorPil from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor from ...processing_utils import ( @@ -729,7 +730,7 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: def forward( self, pixel_values: torch.FloatTensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + image_grid_thw: torch.LongTensor | None = None, ) -> torch.Tensor: """ Args: @@ -787,9 +788,10 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, inputs_embeds: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -797,36 +799,32 @@ def forward( Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. + rotary_pos_ids (`torch.Tensor` of shape `(sequence_length,)`, *optional*): + Precomputed rotary position ids. 
If not provided, will be computed based on `grid_thw`. """ - device = inputs_embeds.device hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, ) - split_hids = [] - split_wids = [] - for t, h, w in image_grid_thw: - image_pids = torch.arange(t * h * w, device=device) % (h * w) - sample_hids = image_pids // w - sample_wids = image_pids % w - split_hids.append(sample_hids) - split_wids.append(sample_wids) - width_position_ids = torch.concat(split_wids, dim=0) - height_position_ids = torch.concat(split_hids, dim=0) - - pids = torch.stack([height_position_ids, width_position_ids], dim=-1) + + if rotary_pos_ids is None: + pids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) + rotary_embeddings = self.rotary_pos_emb(pids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, @@ -865,29 +863,30 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): The tensors corresponding to the input images. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): + Precomputed cumulative sequence lengths. """ - hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw) - + hidden_states = self.embeddings(pixel_values, grid_thw=grid_thw) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, + grid_thw=grid_thw, cu_seqlens=cu_seqlens, attention_mask=attention_mask, - image_grid_thw=image_grid_thw, + rotary_pos_ids=rotary_pos_ids, **kwargs, ) @@ -916,23 +915,25 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
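The optional `cu_seqlens` / `rotary_pos_ids` arguments threaded through these paddleocr_vl forwards imply a simple contract: whatever the caller did not precompute is derived from `grid_thw` with the shared utilities. A standalone sketch of that fallback (the helper name here is illustrative, not part of the library):

    import torch
    from transformers.modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens

    def resolve_vision_precompute(grid_thw, spatial_merge_size, cu_seqlens=None, rotary_pos_ids=None):
        # fall back to the shared utilities only for tensors the caller did not supply
        if rotary_pos_ids is None:
            rotary_pos_ids = get_rotary_pos_ids(grid_thw, spatial_merge_size)
        if cu_seqlens is None:
            cu_seqlens = get_vision_cu_seqlens(grid_thw)
        return cu_seqlens, rotary_pos_ids

    grid = torch.tensor([[1, 4, 4]])
    cu, pos = resolve_vision_precompute(grid, spatial_merge_size=2)
    print(cu.tolist(), tuple(pos.shape))  # [0, 16] (16, 2)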
cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): The cumulative sequence lengths of each image or video feature. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. """ return self.vision_model( pixel_values=pixel_values, - cu_seqlens=cu_seqlens, - image_grid_thw=image_grid_thw, + image_grid_thw=grid_thw, + image_cu_seqlens=cu_seqlens, + image_rotary_pos_ids=rotary_pos_ids, **kwargs, ) @@ -972,6 +973,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -981,19 +984,11 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) - cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) vision_outputs = self.visual( pixel_values=pixel_values, image_grid_thw=image_grid_thw, - cu_seqlens=cu_seqlens, + image_cu_seqlens=image_cu_seqlens, + image_rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) From c0fdc0da10ccde0d42c9ba879ea9ad820c260d0f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 15:36:41 +0200 Subject: [PATCH 18/56] revert paddle ocr until another time --- .../paddleocr_vl/modeling_paddleocr_vl.py | 105 +++++++++--------- .../paddleocr_vl/modular_paddleocr_vl.py | 87 ++++++++------- 2 files changed, 98 insertions(+), 94 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 97df06918065..8ed3be0ad4be 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -42,7 +42,6 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check, torch_int from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -99,8 +98,10 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = 
torch.outer(seq, self.inv_freq) + return freqs class PaddleOCRRotaryEmbedding(nn.Module): @@ -587,7 +588,7 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: def forward( self, pixel_values: torch.FloatTensor, - image_grid_thw: torch.LongTensor | None = None, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, ) -> torch.Tensor: """ Args: @@ -820,10 +821,9 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, inputs_embeds: torch.FloatTensor, - grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, attention_mask: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -831,32 +831,38 @@ def forward( Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - rotary_pos_ids (`torch.Tensor` of shape `(sequence_length,)`, *optional*): - Precomputed rotary position ids. If not provided, will be computed based on `grid_thw`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
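The revert above restores the seqlen-table form of `PaddleOCRVisionRotaryEmbedding.forward`, while the shared-utility style multiplies precomputed (row, col) ids against `inv_freq` directly. A small sketch with toy values showing the two agree once the table is indexed at those ids:

    import torch

    dim, theta = 8, 10000.0
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

    # seqlen-table variant (restored here): build a table, index it later
    table = torch.outer(torch.arange(6, dtype=torch.float), inv_freq)  # (6, dim // 2)

    # pos_ids variant (earlier patches): broadcast against the ids directly
    pos_ids = torch.tensor([[0, 1], [2, 3]])                           # (tokens, 2)
    direct = (pos_ids.unsqueeze(-1) * inv_freq).flatten(1)             # (tokens, dim)

    assert torch.allclose(table[pos_ids].flatten(1), direct)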
""" + device = inputs_embeds.device hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, ) - - if rotary_pos_ids is None: - pids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) - - rotary_embeddings = self.rotary_pos_emb(pids) + split_hids = [] + split_wids = [] + for t, h, w in image_grid_thw: + image_pids = torch.arange(t * h * w, device=device) % (h * w) + sample_hids = image_pids // w + sample_wids = image_pids % w + split_hids.append(sample_hids) + split_wids.append(sample_wids) + width_position_ids = torch.concat(split_wids, dim=0) + height_position_ids = torch.concat(split_hids, dim=0) + + pids = torch.stack([height_position_ids, width_position_ids], dim=-1) + max_grid_size = pids.max() + 1 + rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size) + rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, @@ -895,30 +901,29 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, attention_mask: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): The tensors corresponding to the input images. + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): - Precomputed cumulative sequence lengths. """ - hidden_states = self.embeddings(pixel_values, grid_thw=grid_thw) + hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw) + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, - grid_thw=grid_thw, cu_seqlens=cu_seqlens, attention_mask=attention_mask, - rotary_pos_ids=rotary_pos_ids, + image_grid_thw=image_grid_thw, **kwargs, ) @@ -947,25 +952,23 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. 
- grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): The cumulative sequence lengths of each image or video feature. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. """ return self.vision_model( pixel_values=pixel_values, - image_grid_thw=grid_thw, - image_cu_seqlens=cu_seqlens, - image_rotary_pos_ids=rotary_pos_ids, + cu_seqlens=cu_seqlens, + image_grid_thw=image_grid_thw, **kwargs, ) @@ -1212,8 +1215,6 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1223,11 +1224,19 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) + cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) vision_outputs = self.visual( pixel_values=pixel_values, image_grid_thw=image_grid_thw, - image_cu_seqlens=image_cu_seqlens, - image_rotary_pos_ids=image_rotary_pos_ids, + cu_seqlens=cu_seqlens, return_dict=True, **kwargs, ) @@ -1393,8 +1402,6 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1402,18 +1409,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
""" - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_rotary_pos_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index cfd96f97b2f0..12d935978415 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -37,7 +37,6 @@ from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...models.qwen2_vl.image_processing_pil_qwen2_vl import Qwen2VLImageProcessorPil from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor from ...processing_utils import ( @@ -534,7 +533,7 @@ class PaddleOCRVLConfig(Qwen2VLConfig): video_token_id: int = 100296 vision_start_token_id: int = 101305 vision_end_token_id: int = 101306 - tie_word_embeddings: bool = True + tie_word_embeddings: int = True class PaddleOCRProjector(nn.Module): @@ -730,7 +729,7 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: def forward( self, pixel_values: torch.FloatTensor, - image_grid_thw: torch.LongTensor | None = None, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, ) -> torch.Tensor: """ Args: @@ -788,10 +787,9 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, inputs_embeds: torch.FloatTensor, - grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, attention_mask: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -799,32 +797,38 @@ def forward( Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - rotary_pos_ids (`torch.Tensor` of shape `(sequence_length,)`, *optional*): - Precomputed rotary position ids. If not provided, will be computed based on `grid_thw`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
""" + device = inputs_embeds.device hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, ) - - if rotary_pos_ids is None: - pids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) - - rotary_embeddings = self.rotary_pos_emb(pids) + split_hids = [] + split_wids = [] + for t, h, w in image_grid_thw: + image_pids = torch.arange(t * h * w, device=device) % (h * w) + sample_hids = image_pids // w + sample_wids = image_pids % w + split_hids.append(sample_hids) + split_wids.append(sample_wids) + width_position_ids = torch.concat(split_wids, dim=0) + height_position_ids = torch.concat(split_hids, dim=0) + + pids = torch.stack([height_position_ids, width_position_ids], dim=-1) + max_grid_size = pids.max() + 1 + rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size) + rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, @@ -863,30 +867,29 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, attention_mask: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): The tensors corresponding to the input images. + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): - Precomputed cumulative sequence lengths. """ - hidden_states = self.embeddings(pixel_values, grid_thw=grid_thw) + hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw) + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, - grid_thw=grid_thw, cu_seqlens=cu_seqlens, attention_mask=attention_mask, - rotary_pos_ids=rotary_pos_ids, + image_grid_thw=image_grid_thw, **kwargs, ) @@ -915,25 +918,23 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, + image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. 
- grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): The cumulative sequence lengths of each image or video feature. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. """ return self.vision_model( pixel_values=pixel_values, - image_grid_thw=grid_thw, - image_cu_seqlens=cu_seqlens, - image_rotary_pos_ids=rotary_pos_ids, + cu_seqlens=cu_seqlens, + image_grid_thw=image_grid_thw, **kwargs, ) @@ -973,8 +974,6 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -984,11 +983,19 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) + cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) vision_outputs = self.visual( pixel_values=pixel_values, image_grid_thw=image_grid_thw, - image_cu_seqlens=image_cu_seqlens, - image_rotary_pos_ids=image_rotary_pos_ids, + cu_seqlens=cu_seqlens, return_dict=True, **kwargs, ) From d1da022942b056f8e2c22976ca00665cb939d382 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 15:48:38 +0200 Subject: [PATCH 19/56] finally fixed paddle ocr --- .../paddleocr_vl/modeling_paddleocr_vl.py | 110 +++++++++--------- .../paddleocr_vl/modular_paddleocr_vl.py | 92 +++++++-------- 2 files changed, 98 insertions(+), 104 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 8ed3be0ad4be..99f7167dae2f 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -42,6 +42,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check, torch_int from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults @@ -98,10 +99,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = 
torch.outer(seq, self.inv_freq) - return freqs + def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: + return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class PaddleOCRRotaryEmbedding(nn.Module): @@ -588,13 +587,13 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: def forward( self, pixel_values: torch.FloatTensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + grid_thw: torch.LongTensor | None = None, ) -> torch.Tensor: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ batch_size, squence_len, channel, height, width = pixel_values.shape @@ -607,8 +606,7 @@ def forward( start = 0 embeddings = embeddings.squeeze(0) tmp_embeddings = [] - for image_grid in image_grid_thw: - t, h, w = image_grid + for t, h, w in grid_thw: end = start + t * h * w image_embeddings = embeddings[start:end, :] position_embedding = self.interpolate_pos_encoding(image_embeddings, h, w).squeeze(0).repeat(t, 1) @@ -821,9 +819,10 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, inputs_embeds: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -831,38 +830,32 @@ def forward( Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. + rotary_pos_ids (`torch.Tensor` of shape `(sequence_length,)`, *optional*): + Precomputed rotary position ids. If not provided, will be computed based on `image_grid_thw`. 
""" - device = inputs_embeds.device hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, ) - split_hids = [] - split_wids = [] - for t, h, w in image_grid_thw: - image_pids = torch.arange(t * h * w, device=device) % (h * w) - sample_hids = image_pids // w - sample_wids = image_pids % w - split_hids.append(sample_hids) - split_wids.append(sample_wids) - width_position_ids = torch.concat(split_wids, dim=0) - height_position_ids = torch.concat(split_hids, dim=0) - - pids = torch.stack([height_position_ids, width_position_ids], dim=-1) - max_grid_size = pids.max() + 1 - rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size) - rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1) + + if rotary_pos_ids is None: + pids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) + + rotary_embeddings = self.rotary_pos_emb(pids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, @@ -901,29 +894,30 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): The tensors corresponding to the input images. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): + Precomputed cumulative sequence lengths. """ - hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw) - + hidden_states = self.embeddings(pixel_values, grid_thw=grid_thw) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, + grid_thw=grid_thw, cu_seqlens=cu_seqlens, attention_mask=attention_mask, - image_grid_thw=image_grid_thw, + rotary_pos_ids=rotary_pos_ids, **kwargs, ) @@ -952,23 +946,25 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. 
+ grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): The cumulative sequence lengths of each image or video feature. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. """ return self.vision_model( pixel_values=pixel_values, + grid_thw=grid_thw, cu_seqlens=cu_seqlens, - image_grid_thw=image_grid_thw, + rotary_pos_ids=rotary_pos_ids, **kwargs, ) @@ -1215,6 +1211,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1224,19 +1222,11 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) - cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) vision_outputs = self.visual( pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - cu_seqlens=cu_seqlens, + grid_thw=image_grid_thw, + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) @@ -1402,6 +1392,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1409,8 +1401,18 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_cu_seqlens, + image_rotary_pos_ids, + **kwargs, + ) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 12d935978415..edb80f1019f3 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -37,6 +37,7 @@ from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...models.qwen2_vl.image_processing_pil_qwen2_vl import Qwen2VLImageProcessorPil from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor from ...processing_utils import ( @@ -533,7 +534,7 @@ class PaddleOCRVLConfig(Qwen2VLConfig): video_token_id: int = 100296 vision_start_token_id: int = 101305 vision_end_token_id: int = 101306 - tie_word_embeddings: int = True + tie_word_embeddings: bool = True class PaddleOCRProjector(nn.Module): @@ -729,13 +730,13 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: def forward( self, pixel_values: torch.FloatTensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + grid_thw: torch.LongTensor | None = None, ) -> torch.Tensor: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ batch_size, squence_len, channel, height, width = pixel_values.shape @@ -748,8 +749,7 @@ def forward( start = 0 embeddings = embeddings.squeeze(0) tmp_embeddings = [] - for image_grid in image_grid_thw: - t, h, w = image_grid + for t, h, w in grid_thw: end = start + t * h * w image_embeddings = embeddings[start:end, :] position_embedding = self.interpolate_pos_encoding(image_embeddings, h, w).squeeze(0).repeat(t, 1) @@ -787,9 +787,10 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, inputs_embeds: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -797,38 +798,32 @@ def forward( Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
+ cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. + rotary_pos_ids (`torch.Tensor` of shape `(sequence_length,)`, *optional*): + Precomputed rotary position ids. If not provided, will be computed based on `image_grid_thw`. """ - device = inputs_embeds.device hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, ) - split_hids = [] - split_wids = [] - for t, h, w in image_grid_thw: - image_pids = torch.arange(t * h * w, device=device) % (h * w) - sample_hids = image_pids // w - sample_wids = image_pids % w - split_hids.append(sample_hids) - split_wids.append(sample_wids) - width_position_ids = torch.concat(split_wids, dim=0) - height_position_ids = torch.concat(split_hids, dim=0) - - pids = torch.stack([height_position_ids, width_position_ids], dim=-1) - max_grid_size = pids.max() + 1 - rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size) - rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1) + + if rotary_pos_ids is None: + pids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) + + rotary_embeddings = self.rotary_pos_emb(pids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) + if cu_seqlens is None: + cu_seqlens = get_vision_cu_seqlens(grid_thw) + for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, @@ -867,29 +862,30 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): The tensors corresponding to the input images. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): + Precomputed cumulative sequence lengths. 
""" - hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw) - + hidden_states = self.embeddings(pixel_values, grid_thw=grid_thw) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, + grid_thw=grid_thw, cu_seqlens=cu_seqlens, attention_mask=attention_mask, - image_grid_thw=image_grid_thw, + rotary_pos_ids=rotary_pos_ids, **kwargs, ) @@ -918,23 +914,25 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + grid_thw: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, + rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): The cumulative sequence lengths of each image or video feature. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. """ return self.vision_model( pixel_values=pixel_values, + grid_thw=grid_thw, cu_seqlens=cu_seqlens, - image_grid_thw=image_grid_thw, + rotary_pos_ids=rotary_pos_ids, **kwargs, ) @@ -974,6 +972,8 @@ def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -983,19 +983,11 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) - cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) vision_outputs = self.visual( pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - cu_seqlens=cu_seqlens, + grid_thw=image_grid_thw, + cu_seqlens=image_cu_seqlens, + rotary_pos_ids=image_rotary_pos_ids, return_dict=True, **kwargs, ) From 448ff2eed04d9ab4f5d254d39e9d30665e359533 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 16:04:46 +0200 Subject: [PATCH 20/56] fix review --- src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py | 4 ++-- src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py | 4 ++-- src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py | 3 --- src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py | 3 --- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 99f7167dae2f..757edf82cd6f 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -847,9 +847,9 @@ def forward( ) if rotary_pos_ids is None: - pids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) - rotary_embeddings = self.rotary_pos_emb(pids) + rotary_embeddings = self.rotary_pos_emb(rotary_pos_ids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index edb80f1019f3..c55ce84acbbc 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -815,9 +815,9 @@ def forward( ) if rotary_pos_ids is None: - pids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) + rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) - rotary_embeddings = self.rotary_pos_emb(pids) + rotary_embeddings = self.rotary_pos_emb(rotary_pos_ids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 59f77f7411b1..eb595e20f208 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -882,9 +882,6 @@ def forward( if pool_indices is None: pool_indices = get_pool_indices(feature_lens) - if aftercnn_lens is None and feature_lens is not None: - aftercnn_lens, _ = self._get_feat_extract_output_lengths(feature_lens) - # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) padded_feature = padded_feature.to(self.conv1.weight.dtype) 
padded_mask = ( diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 062546e263bf..0a03f26ea67f 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1364,9 +1364,6 @@ def forward( if pool_indices is None: pool_indices = get_pool_indices(feature_lens) - if aftercnn_lens is None and feature_lens is not None: - aftercnn_lens, _ = self._get_feat_extract_output_lengths(feature_lens) - # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) padded_feature = padded_feature.to(self.conv1.weight.dtype) padded_mask = ( From 6731028e84486874a6b2dfd90402763ec8d72414 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 16:10:18 +0200 Subject: [PATCH 21/56] revert chunking --- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 13 +++++++++---- .../models/qwen3_omni_moe/modular_qwen3_omni_moe.py | 13 +++++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index df9b0228d0e9..71d6a2cf38ed 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -782,10 +782,15 @@ def forward( cu_seqlens = get_audio_cu_seqlens(chunk_lengths, feature_lens, self.n_window_infer, self.n_window) # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) - padded_feature = padded_feature.unsqueeze(1) - padded_embed = F.gelu(self.conv2d1(padded_feature)) - padded_embed = F.gelu(self.conv2d2(padded_embed)) - padded_embed = F.gelu(self.conv2d3(padded_embed)) + # Split to chunk to avoid OOM during convolution + padded_embeds = [] + for chunk in padded_feature.split(self.conv_chunksize, dim=0): + padded_embed = F.gelu(self.conv2d1(chunk)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = F.gelu(self.conv2d3(padded_embed)) + padded_embeds.append(padded_embed) + padded_embed = torch.cat(padded_embeds, dim=0) + b, c, f, t = padded_embed.size() padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index fa2cbdb4e77d..c0eeff67b989 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1011,10 +1011,15 @@ def forward( cu_seqlens = get_audio_cu_seqlens(chunk_lengths, feature_lens, self.n_window_infer, self.n_window) # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) - padded_feature = padded_feature.unsqueeze(1) - padded_embed = F.gelu(self.conv2d1(padded_feature)) - padded_embed = F.gelu(self.conv2d2(padded_embed)) - padded_embed = F.gelu(self.conv2d3(padded_embed)) + # Split to chunk to avoid OOM during convolution + padded_embeds = [] + for chunk in padded_feature.split(self.conv_chunksize, dim=0): + padded_embed = F.gelu(self.conv2d1(chunk)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = F.gelu(self.conv2d3(padded_embed)) + padded_embeds.append(padded_embed) + padded_embed = torch.cat(padded_embeds, dim=0) + b, c, f, t = padded_embed.size() padded_embed = 
self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) From 693ba9cf2131d3c23178f41a9bd19ea57a4a4971 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Mon, 13 Apr 2026 16:22:41 +0200 Subject: [PATCH 22/56] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index c55ce84acbbc..5bdb8b629e9d 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -804,8 +804,9 @@ def forward( The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - rotary_pos_ids (`torch.Tensor` of shape `(sequence_length,)`, *optional*): - Precomputed rotary position ids. If not provided, will be computed based on `image_grid_thw`. + rotary_pos_ids (`torch.Tensor` of shape `(sequence_length, 2)`, *optional*): + Precomputed rotary position ids as `(row, column)` pairs. If not provided, will be computed based on + `image_grid_thw`. """ hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( From d701016f57471803c5fa2d236c72be077de5bacb Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Mon, 13 Apr 2026 16:22:54 +0200 Subject: [PATCH 23/56] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../models/paddleocr_vl/modeling_paddleocr_vl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 757edf82cd6f..018b47fff3a7 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -836,8 +836,9 @@ def forward( The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - rotary_pos_ids (`torch.Tensor` of shape `(sequence_length,)`, *optional*): - Precomputed rotary position ids. If not provided, will be computed based on `image_grid_thw`. + rotary_pos_ids (`torch.Tensor` of shape `(sequence_length, 2)`, *optional*): + Precomputed rotary position ids for the row and column coordinates. If not provided, will be computed + based on `image_grid_thw`. 
""" hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( From 5472c4f351190589dac7e8fa6bebd4e265eec2da Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 16:35:00 +0200 Subject: [PATCH 24/56] fix torch compilable check --- src/transformers/utils/import_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 1e1ac2545f05..d8d09351995b 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1522,6 +1522,16 @@ def torch_compilable_check(cond: Any, msg: str | Callable[[], str], error_type: import torch + # When tracing, msg may be an f-string with tensor values that dynamo can't trace + # (callable/isinstance on it breaks). Check compilation first and use torch._check + # without msg (it only serves as a compiler hint in that case). + if is_tracing(): + if isinstance(cond, torch.Tensor): + torch._check_tensor_all(cond) + else: + torch._check(cond) + return + if not callable(msg): # torch._check requires msg to be a callable but we want to keep the API simple for users def msg_callable(): From 4e7739bf7f1e4af733d655e6af5c0282b53bd327 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 16:47:48 +0200 Subject: [PATCH 25/56] fix docs --- src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 018b47fff3a7..640d69b9fd89 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -837,8 +837,8 @@ def forward( attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. rotary_pos_ids (`torch.Tensor` of shape `(sequence_length, 2)`, *optional*): - Precomputed rotary position ids for the row and column coordinates. If not provided, will be computed - based on `image_grid_thw`. + Precomputed rotary position ids as `(row, column)` pairs. If not provided, will be computed based on + `image_grid_thw`. 
""" hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( From 47fed92c34ac21677cc3b0e3b38101db94994c02 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 16:48:45 +0200 Subject: [PATCH 26/56] correct func name --- .../models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py | 4 ++-- src/transformers/models/glm46v/modeling_glm46v.py | 4 ++-- src/transformers/models/glm4v/modeling_glm4v.py | 4 ++-- src/transformers/models/glm4v_moe/modeling_glm4v_moe.py | 4 ++-- src/transformers/models/glm_ocr/modeling_glm_ocr.py | 4 ++-- .../models/paddleocr_vl/modeling_paddleocr_vl.py | 2 +- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 4 ++-- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 8 ++++---- src/transformers/models/qwen3_5/modeling_qwen3_5.py | 4 ++-- .../models/qwen3_5_moe/modeling_qwen3_5_moe.py | 4 ++-- src/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 4 ++-- .../models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 4 ++-- .../models/video_llama_3/modeling_video_llama_3.py | 4 ++-- 13 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index b2f819459fd0..6c4444060f2f 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -1600,7 +1600,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ @@ -1627,7 +1627,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index a05ffe5d1890..b8a50312c5aa 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -570,7 +570,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). 
""" @@ -597,7 +597,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index c488db72ca67..92f9dc7e610c 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1389,7 +1389,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ @@ -1416,7 +1416,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 30e692a5b318..63f01eed7743 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -1613,7 +1613,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ @@ -1640,7 +1640,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
""" diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index a170fabee6e7..e3ebce52d96a 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -1306,7 +1306,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ @@ -1333,7 +1333,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 640d69b9fd89..227610f26564 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -1403,7 +1403,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index eb15c13b8363..3c643886a56d 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1388,7 +1388,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ @@ -1415,7 +1415,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). 
image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index e88ad7497e0b..7fc1149e9672 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1350,7 +1350,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ @@ -1377,7 +1377,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ @@ -1425,11 +1425,11 @@ def forward( rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index a806e531371a..78b72f410f60 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1761,7 +1761,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ @@ -1788,7 +1788,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index dd58f1a99acd..25fd6ca36f02 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1960,7 +1960,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ @@ -1987,7 +1987,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 6c8a4ba9fc61..8630e2027fb8 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -1361,7 +1361,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ @@ -1388,7 +1388,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
""" diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 68212c551468..30d5e3183dff 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -1544,7 +1544,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ @@ -1571,7 +1571,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 17eb01dfc528..825300d80125 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -773,7 +773,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). video_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ @@ -800,7 +800,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_cu_seqlens`). + Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). image_rotary_pos_ids (`torch.LongTensor`, *optional*): Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
""" From 18a1788964193806f7edf84d146ec3df6eceb11e Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 17:00:44 +0200 Subject: [PATCH 27/56] fix omni --- .../models/qwen3_omni_moe/modeling_qwen3_omni_moe.py | 1 + src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 71d6a2cf38ed..d071d0bbd8a4 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -782,6 +782,7 @@ def forward( cu_seqlens = get_audio_cu_seqlens(chunk_lengths, feature_lens, self.n_window_infer, self.n_window) # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) + padded_feature = padded_feature.unsqueeze(1) # Split to chunk to avoid OOM during convolution padded_embeds = [] for chunk in padded_feature.split(self.conv_chunksize, dim=0): diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index c0eeff67b989..6b4658e73682 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1011,6 +1011,7 @@ def forward( cu_seqlens = get_audio_cu_seqlens(chunk_lengths, feature_lens, self.n_window_infer, self.n_window) # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) + padded_feature = padded_feature.unsqueeze(1) # Split to chunk to avoid OOM during convolution padded_embeds = [] for chunk in padded_feature.split(self.conv_chunksize, dim=0): From 4c6e1dfea0e1e34a468f73c76c6649e327e861a1 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 17:26:12 +0200 Subject: [PATCH 28/56] fix video llama 3 --- .../video_llama_3/modeling_video_llama_3.py | 26 +++++--- .../video_llama_3/modular_video_llama_3.py | 64 +++++++++++++++++++ 2 files changed, 82 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 825300d80125..5e5595472001 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -758,11 +758,13 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_merge_sizes: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, video_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], @@ -772,24 +774,29 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`, *optional*): + The spatial downsampling ratio of each video feature. video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). + Precomputed cumulative sequence lengths for videos. 
video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). + Precomputed (row, col) position IDs for video rotary embeddings. """ return self.model.get_video_features( pixel_values_videos, video_grid_thw, - video_cu_seqlens, - video_rotary_pos_ids, + video_merge_sizes, + video_cu_seqlens=video_cu_seqlens, + video_rotary_pos_ids=video_rotary_pos_ids, **kwargs, ) + @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_merge_sizes: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, image_rotary_pos_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], @@ -799,16 +806,19 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_merge_sizes (`torch.Tensor` of shape `(num_images,)`, *optional*): + The spatial downsampling ratio of each image feature. image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). + Precomputed cumulative sequence lengths for images. image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). + Precomputed (row, col) position IDs for image rotary embeddings. """ return self.model.get_image_features( pixel_values, image_grid_thw, - image_cu_seqlens, - image_rotary_pos_ids, + image_merge_sizes, + image_cu_seqlens=image_cu_seqlens, + image_rotary_pos_ids=image_rotary_pos_ids, **kwargs, ) diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index deace870771a..3045be57bc74 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -667,6 +667,70 @@ class VideoLlama3ForConditionalGeneration(Qwen2VLForConditionalGeneration): def __init__(self, config: VideoLlama3Config): super().__init__(config) # just to add type hint on config + @can_return_tuple + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.LongTensor | None = None, + image_merge_sizes: torch.LongTensor | None = None, + image_cu_seqlens: torch.Tensor | None = None, + image_rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + image_merge_sizes (`torch.Tensor` of shape `(num_images,)`, *optional*): + The spatial downsampling ratio of each image feature. + image_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for images. + image_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for image rotary embeddings. 
+ """ + return self.model.get_image_features( + pixel_values, + image_grid_thw, + image_merge_sizes, + image_cu_seqlens=image_cu_seqlens, + image_rotary_pos_ids=image_rotary_pos_ids, + **kwargs, + ) + + @can_return_tuple + @auto_docstring + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.LongTensor | None = None, + video_merge_sizes: torch.LongTensor | None = None, + video_cu_seqlens: torch.Tensor | None = None, + video_rotary_pos_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`, *optional*): + The spatial downsampling ratio of each video feature. + video_cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths for videos. + video_rotary_pos_ids (`torch.LongTensor`, *optional*): + Precomputed (row, col) position IDs for video rotary embeddings. + """ + return self.model.get_video_features( + pixel_values_videos, + video_grid_thw, + video_merge_sizes, + video_cu_seqlens=video_cu_seqlens, + video_rotary_pos_ids=video_rotary_pos_ids, + **kwargs, + ) + @can_return_tuple @auto_docstring def forward( From 247b445a1a38b5a1b376a2565a2c7b9b510bb651 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 13 Apr 2026 17:36:48 +0200 Subject: [PATCH 29/56] fix video llama 3 --- .../models/video_llama_3/modeling_video_llama_3.py | 8 -------- .../models/video_llama_3/modular_video_llama_3.py | 8 -------- 2 files changed, 16 deletions(-) diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 5e5595472001..c226765acaf3 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -776,10 +776,6 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`, *optional*): The spatial downsampling ratio of each video feature. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos. - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ return self.model.get_video_features( pixel_values_videos, @@ -808,10 +804,6 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_merge_sizes (`torch.Tensor` of shape `(num_images,)`, *optional*): The spatial downsampling ratio of each image feature. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images. - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. 
""" return self.model.get_image_features( pixel_values, diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 3045be57bc74..71ff165a53c9 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -685,10 +685,6 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_merge_sizes (`torch.Tensor` of shape `(num_images,)`, *optional*): The spatial downsampling ratio of each image feature. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images. - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings. """ return self.model.get_image_features( pixel_values, @@ -717,10 +713,6 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`, *optional*): The spatial downsampling ratio of each video feature. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos. - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings. """ return self.model.get_video_features( pixel_values_videos, From 3c5e9a8a14a7eb502ffa6d4909d8611fbf1844f6 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 14 Apr 2026 08:40:30 +0200 Subject: [PATCH 30/56] requires torch --- src/transformers/modeling_vision_utils.py | 9 +++++---- src/transformers/models/glm46v/processing_glm46v.py | 3 +++ src/transformers/models/glm4v/modular_glm4v.py | 3 +++ src/transformers/models/glm4v/processing_glm4v.py | 3 +++ src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 3 +++ .../models/qwen2_5_vl/processing_qwen2_5_vl.py | 3 +++ src/transformers/models/qwen2_vl/processing_qwen2_vl.py | 3 +++ src/transformers/models/qwen3_vl/modular_qwen3_vl.py | 3 +++ src/transformers/models/qwen3_vl/processing_qwen3_vl.py | 3 +++ .../models/video_llama_3/modular_video_llama_3.py | 2 ++ .../models/video_llama_3/processing_video_llama_3.py | 2 ++ 11 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_vision_utils.py b/src/transformers/modeling_vision_utils.py index 84417033eb91..cff2f599448f 100644 --- a/src/transformers/modeling_vision_utils.py +++ b/src/transformers/modeling_vision_utils.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Pure vision utility functions for computing data-dependent tensors. +"""Pure vision utility functions for pre-computing very dynamic and +data-dependent tensors that can break model capturing and tracing. All functions are standalone (no model weights) and compute tensors from ``grid_thw`` + config scalars. They are used by vision encoders and can be -precomputed before ``torch.export`` tracing since they use untraceable ops -(``repeat_interleave``, ``.tolist()``, ``nonzero()``, loops). +precomputed before `torch.compile` / ``torch.export`` tracing since they +use untraceable ops (``repeat_interleave``, ``.tolist()``, ``nonzero()``, loops). 
""" from __future__ import annotations @@ -56,7 +57,7 @@ def get_rotary_pos_ids(grid_thw: torch.Tensor, spatial_merge_size: int | torch.T ``pos_ids``: ``(total_tokens, 2)`` long — (row, col) position per token. """ device = grid_thw.device - if not isinstance(spatial_merge_size, torch.Tensor): + if isinstance(spatial_merge_size, int): spatial_merge_size = torch.tensor([spatial_merge_size], device=device).expand(len(grid_thw)) pos_ids = [] diff --git a/src/transformers/models/glm46v/processing_glm46v.py b/src/transformers/models/glm46v/processing_glm46v.py index 7687d55478b0..ab9713e1e4a8 100644 --- a/src/transformers/models/glm46v/processing_glm46v.py +++ b/src/transformers/models/glm46v/processing_glm46v.py @@ -27,6 +27,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging +from ...utils.import_utils import requires_backends from ...video_utils import VideoInput @@ -94,6 +95,7 @@ def __call__( image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.image_processor.merge_size image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) @@ -110,6 +112,7 @@ def __call__( video_metadata = videos_inputs["video_metadata"] video_grid_thw = videos_inputs["video_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.video_processor.merge_size videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 56b5ae0b6a39..cc7678253a95 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -44,6 +44,7 @@ torch_compilable_check, ) from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.import_utils import requires_backends from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, Glm4RotaryEmbedding, eager_attention_forward @@ -1289,6 +1290,7 @@ def __call__( image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.image_processor.merge_size image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) @@ -1305,6 +1307,7 @@ def __call__( video_metadata = videos_inputs["video_metadata"] video_grid_thw = videos_inputs["video_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.video_processor.merge_size videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) diff --git a/src/transformers/models/glm4v/processing_glm4v.py b/src/transformers/models/glm4v/processing_glm4v.py index eaf8a5d90e1b..18e59c8db2e5 100644 --- 
a/src/transformers/models/glm4v/processing_glm4v.py +++ b/src/transformers/models/glm4v/processing_glm4v.py @@ -26,6 +26,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging +from ...utils.import_utils import requires_backends from ...video_utils import VideoInput @@ -93,6 +94,7 @@ def __call__( image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.image_processor.merge_size image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) @@ -109,6 +111,7 @@ def __call__( video_metadata = videos_inputs["video_metadata"] video_grid_thw = videos_inputs["video_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.video_processor.merge_size videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index d9b27aafbb9c..46b670144ff2 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -38,6 +38,7 @@ from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults +from ...utils.import_utils import requires_backends from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput from ..llama.modeling_llama import LlamaRMSNorm @@ -782,6 +783,7 @@ def __call__( image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.image_processor.merge_size image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) @@ -790,6 +792,7 @@ def __call__( videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.video_processor.merge_size videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index f6785781be2c..7fb4542beedc 100644 --- a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -28,6 +28,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring +from ...utils.import_utils import requires_backends from ...video_utils import VideoInput @@ -92,6 +93,7 @@ def __call__( image_inputs = 
self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.image_processor.merge_size image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) @@ -100,6 +102,7 @@ def __call__( videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.video_processor.merge_size videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) diff --git a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py index 2b8aa340e1f1..ee30f820423c 100644 --- a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py @@ -26,6 +26,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging +from ...utils.import_utils import requires_backends from ...video_utils import VideoInput @@ -91,6 +92,7 @@ def __call__( image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.image_processor.merge_size image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) @@ -99,6 +101,7 @@ def __call__( videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.video_processor.merge_size videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index c0d58de15485..ccb80817fdc3 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -39,6 +39,7 @@ from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, can_return_tuple, logging from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.import_utils import requires_backends from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput from ..llama.modeling_llama import LlamaRotaryEmbedding @@ -1168,6 +1169,7 @@ def __call__( image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.image_processor.merge_size image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) @@ -1179,6 +1181,7 @@ def __call__( 
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.video_processor.merge_size videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) diff --git a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py index c9098d9436b7..594994b7adb6 100644 --- a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py @@ -27,6 +27,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging +from ...utils.import_utils import requires_backends from ...video_utils import VideoInput @@ -108,6 +109,7 @@ def __call__( image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.image_processor.merge_size image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) @@ -119,6 +121,7 @@ def __call__( videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] if return_extra_tensors: + requires_backends(self, ["torch"]) spatial_merge_size = self.video_processor.merge_size videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 71ff165a53c9..ca2fb598213e 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -47,6 +47,7 @@ logging, ) from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults +from ...utils.import_utils import requires_backends from ...utils.output_capturing import capture_outputs from ...video_utils import ( VideoInput, @@ -1078,6 +1079,7 @@ def __call__( image_grid_thw = image_inputs["image_grid_thw"] image_merge_sizes = image_inputs["image_merge_sizes"] if return_extra_tensors: + requires_backends(self, ["torch"]) image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, image_merge_sizes) else: diff --git a/src/transformers/models/video_llama_3/processing_video_llama_3.py b/src/transformers/models/video_llama_3/processing_video_llama_3.py index 987cc0069faf..f77f7a46d2ed 100644 --- a/src/transformers/models/video_llama_3/processing_video_llama_3.py +++ b/src/transformers/models/video_llama_3/processing_video_llama_3.py @@ -23,6 +23,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging +from ...utils.import_utils import requires_backends from ...video_utils import VideoInput @@ -90,6 +91,7 @@ def __call__( 
image_grid_thw = image_inputs["image_grid_thw"] image_merge_sizes = image_inputs["image_merge_sizes"] if return_extra_tensors: + requires_backends(self, ["torch"]) image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, image_merge_sizes) else: From 27677edaccbddf2bc4d6c1ed3b051a3fa98166c4 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 14 Apr 2026 08:50:32 +0200 Subject: [PATCH 31/56] add missing grid device --- tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py | 2 +- tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py | 2 +- tests/models/qwen3_5/test_modeling_qwen3_5.py | 4 ++-- tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py | 4 ++-- tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py | 2 +- tests/models/qwen3_vl/test_modeling_qwen3_vl.py | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py index 7de1544d6893..55d3137a4ab7 100644 --- a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py +++ b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py @@ -460,7 +460,7 @@ def test_get_rope_index_video_with_audio(self): image_grid_thw = torch.empty((0, 3), dtype=torch.long) # 3 * 2 * 2 = 12 video tokens - video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long) + video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long, device=torch_device) # num_audio_tokens = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 # i.e.: 300 audio_seqlen -> 75 audio tokens diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index 9876e0d61f4b..0f2a790db979 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -284,7 +284,7 @@ def test_video_forward(self): C * T * (P**2), ] ) - video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * B) + video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * B, device=torch_device) # sanity check assert pixel_values_videos.shape[0] == video_grid_thw.prod(dim=1).sum().item() diff --git a/tests/models/qwen3_5/test_modeling_qwen3_5.py b/tests/models/qwen3_5/test_modeling_qwen3_5.py index 7725d2891a33..543f1ad17036 100644 --- a/tests/models/qwen3_5/test_modeling_qwen3_5.py +++ b/tests/models/qwen3_5/test_modeling_qwen3_5.py @@ -456,7 +456,7 @@ def test_image_forward(self): channels * temporal_patch * (patch_size**2), ] ) - image_grid_thw = torch.tensor([[1, 1, 1]] * (bsz * num_images)) + image_grid_thw = torch.tensor([[1, 1, 1]] * (bsz * num_images), device=torch_device) self.assertEqual(pixel_values.shape[0], image_grid_thw.prod(dim=1).sum().item()) insertion_point = 0 @@ -509,7 +509,7 @@ def test_video_forward(self): ] ) - video_grid_thw = torch.tensor([[patch_t, patch_h, patch_w]] * (bsz * num_video)) + video_grid_thw = torch.tensor([[patch_t, patch_h, patch_w]] * (bsz * num_video), device=torch_device) self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item()) input_ids[:, -1] = self.model_tester.pad_token_id diff --git a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py index e81e4d951917..c5935af0a3be 100644 --- a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py +++ b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py @@ -454,7 +454,7 @@ def test_image_forward(self): channels * 
temporal_patch * (patch_size**2), ] ) - image_grid_thw = torch.tensor([[1, 1, 1]] * (bsz * num_images)) + image_grid_thw = torch.tensor([[1, 1, 1]] * (bsz * num_images), device=torch_device) self.assertEqual(pixel_values.shape[0], image_grid_thw.prod(dim=1).sum().item()) insertion_point = 0 @@ -507,7 +507,7 @@ def test_video_forward(self): ] ) - video_grid_thw = torch.tensor([[patch_t, patch_h, patch_w]] * (bsz * num_video)) + video_grid_thw = torch.tensor([[patch_t, patch_h, patch_w]] * (bsz * num_video), device=torch_device) self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item()) input_ids[:, -1] = self.model_tester.pad_token_id diff --git a/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py b/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py index 3d4113c922ae..7c3f5064b1ff 100644 --- a/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py +++ b/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py @@ -487,7 +487,7 @@ def test_get_rope_index_video_with_audio(self): image_grid_thw = torch.empty((0, 3), dtype=torch.long) # 3 * 2 * 2 = 12 video tokens - video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long) + video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long, device=torch_device) # num_audio_tokens = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 # i.e.: 300 audio_seqlen -> 75 audio tokens diff --git a/tests/models/qwen3_vl/test_modeling_qwen3_vl.py b/tests/models/qwen3_vl/test_modeling_qwen3_vl.py index b7e0b9053c25..e6b483d8378d 100644 --- a/tests/models/qwen3_vl/test_modeling_qwen3_vl.py +++ b/tests/models/qwen3_vl/test_modeling_qwen3_vl.py @@ -217,7 +217,7 @@ def test_image_forward(self): C * T * (P**2), ] ) - image_grid_thw = torch.tensor([[1, 1, 1]] * (B * num_images)) + image_grid_thw = torch.tensor([[1, 1, 1]] * (B * num_images), device=torch_device) self.assertEqual(pixel_values.shape[0], image_grid_thw.prod(dim=1).sum().item()) insertion_point = 0 @@ -272,7 +272,7 @@ def test_video_forward(self): ] ) - video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * (B * num_video)) + video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * (B * num_video), device=torch_device) # sanity check self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item()) From 45a03e4221a72bd1838a80e6bd8a31c2f5138019 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 14 Apr 2026 10:07:40 +0200 Subject: [PATCH 32/56] keep rot emb in fp32 --- .../models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py | 2 +- .../models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py | 2 +- src/transformers/models/glm4v/modeling_glm4v.py | 2 +- src/transformers/models/glm4v/modular_glm4v.py | 2 +- src/transformers/models/glm4v_moe/modeling_glm4v_moe.py | 2 +- src/transformers/models/glm_ocr/modeling_glm_ocr.py | 2 +- src/transformers/models/glm_ocr/modular_glm_ocr.py | 2 +- src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py | 4 +++- src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py | 4 +++- src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py | 2 +- src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py | 2 +- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 2 +- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 2 +- src/transformers/models/qwen3_5/modeling_qwen3_5.py | 2 +- src/transformers/models/qwen3_5/modular_qwen3_5.py | 2 +- src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py | 2 +- 
.../models/qwen3_omni_moe/modeling_qwen3_omni_moe.py | 2 +- src/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 2 +- src/transformers/models/qwen3_vl/modular_qwen3_vl.py | 2 +- src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 2 +- .../models/video_llama_3/modeling_video_llama_3.py | 2 +- .../models/video_llama_3/modular_video_llama_3.py | 2 +- 23 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index 6c4444060f2f..bbcc6b869d38 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -915,7 +915,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index a701d4e006a9..b11d2426a652 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -711,7 +711,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 92f9dc7e610c..d5047bfcc411 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -758,7 +758,7 @@ def forward( if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index cc7678253a95..57397d2d6c1a 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -644,7 +644,7 @@ def forward( if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 63f01eed7743..19cbb8629b4e 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -822,7 +822,7 @@ def forward( if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) 
position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index e3ebce52d96a..c55e79862285 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -609,7 +609,7 @@ def forward( if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/glm_ocr/modular_glm_ocr.py b/src/transformers/models/glm_ocr/modular_glm_ocr.py index a2cf56f94954..5fc29c0c131a 100644 --- a/src/transformers/models/glm_ocr/modular_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modular_glm_ocr.py @@ -276,7 +276,7 @@ def forward( if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 227610f26564..e6398739c7cd 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -848,7 +848,9 @@ def forward( ) if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) + # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), + # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). + rotary_pos_ids = get_rotary_pos_ids(grid_thw, 1) rotary_embeddings = self.rotary_pos_emb(rotary_pos_ids) rotary_embeddings = rotary_embeddings.repeat(1, 2) diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 5bdb8b629e9d..def3604ac6f6 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -816,7 +816,9 @@ def forward( ) if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.config.spatial_merge_size) + # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), + # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). 
+ rotary_pos_ids = get_rotary_pos_ids(grid_thw, 1) rotary_embeddings = self.rotary_pos_emb(rotary_pos_ids) rotary_embeddings = rotary_embeddings.repeat(1, 2) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index eb595e20f208..de660453b1b3 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1239,7 +1239,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 0a03f26ea67f..b89a0ce26ab9 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1594,7 +1594,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 3c643886a56d..8649476d3ba8 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -412,7 +412,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 46b670144ff2..f9893da9bba5 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -254,7 +254,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 7fc1149e9672..e94a78de6321 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -744,7 +744,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 78b72f410f60..8b6f5305bd38 100644 --- 
a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1069,7 +1069,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 3bded3a51c66..99d17e3cc285 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -462,7 +462,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 25fd6ca36f02..af4dd100a933 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1162,7 +1162,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index d071d0bbd8a4..cfeac1349825 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1165,7 +1165,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 8630e2027fb8..1b63d9aa7b31 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -700,7 +700,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index ccb80817fdc3..249904835b00 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -489,7 +489,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) if cu_seqlens is None: cu_seqlens = 
get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 30d5e3183dff..c7b81854d379 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -685,7 +685,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index c226765acaf3..eb7c2f0e0fb9 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -434,7 +434,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, merge_sizes) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index ca2fb598213e..e47c22c0bc85 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -397,7 +397,7 @@ def forward( if rotary_pos_ids is None: rotary_pos_ids = get_rotary_pos_ids(grid_thw, merge_sizes) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids).to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) From 5f3d2ae68b3379e1e8225c0242d055b24792d9ab Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 14 Apr 2026 10:27:59 +0200 Subject: [PATCH 33/56] fix test device --- tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py index 2de7b384d075..ee95efd9befc 100644 --- a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py +++ b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py @@ -217,7 +217,7 @@ def test_image_forward(self): C * T * (P**2), ] ) - image_grid_thw = torch.tensor([[1, 1, 1]] * (B * num_images)) + image_grid_thw = torch.tensor([[1, 1, 1]] * (B * num_images), device=torch_device) self.assertEqual(pixel_values.shape[0], image_grid_thw.prod(dim=1).sum().item()) insertion_point = 0 @@ -272,7 +272,7 @@ def test_video_forward(self): ] ) - video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * (B * num_video)) + video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * (B * num_video), device=torch_device) # sanity check self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item()) From e4c41380672c9fb3dd5332c65d1091f2d23bc47f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 15 Apr 2026 09:30:52 +0200 Subject: [PATCH 34/56] fix glm4v_moe flex attention test --- tests/models/glm4v_moe/test_modeling_glm4v_moe.py | 4 ++-- 1 file changed, 2 insertions(+),
2 deletions(-) diff --git a/tests/models/glm4v_moe/test_modeling_glm4v_moe.py b/tests/models/glm4v_moe/test_modeling_glm4v_moe.py index 917a13bcefaf..dca1361cb597 100644 --- a/tests/models/glm4v_moe/test_modeling_glm4v_moe.py +++ b/tests/models/glm4v_moe/test_modeling_glm4v_moe.py @@ -51,7 +51,7 @@ def __init__( self, parent, batch_size=3, - seq_length=7, + seq_length=64, num_channels=3, ignore_index=-100, image_size=112, @@ -72,7 +72,7 @@ def __init__( "output_channels": 64, "hidden_act": "silu", "max_position_embeddings": 512, - "rope_parameters": {"type": "default", "mrope_section": [1, 1]}, + "rope_parameters": {"type": "default", "mrope_section": [2, 2], "partial_rotary_factor": 1.0}, "rope_theta": 10000, "tie_word_embeddings": True, "bos_token_id": 0, From e4c41380672c9fb3dd5332c65d1091f2d23bc47f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 15 Apr 2026 15:25:50 +0200 Subject: [PATCH 35/56] rename to vision utils --- .../models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py | 2 +- .../models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py | 2 +- src/transformers/models/glm46v/processing_glm46v.py | 2 +- src/transformers/models/glm4v/modeling_glm4v.py | 2 +- src/transformers/models/glm4v/modular_glm4v.py | 2 +- src/transformers/models/glm4v/processing_glm4v.py | 2 +- src/transformers/models/glm4v_moe/modeling_glm4v_moe.py | 2 +- src/transformers/models/glm_image/modeling_glm_image.py | 2 +- src/transformers/models/glm_image/modular_glm_image.py | 2 +- src/transformers/models/glm_ocr/modeling_glm_ocr.py | 2 +- src/transformers/models/glm_ocr/modular_glm_ocr.py | 2 +- src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py | 2 +- src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py | 2 +- src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py | 2 +- src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py | 3 ++- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 2 +- src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py | 2 +- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 2 +- src/transformers/models/qwen2_vl/processing_qwen2_vl.py | 3 ++- src/transformers/models/qwen3_5/modeling_qwen3_5.py | 4 ++-- src/transformers/models/qwen3_5/modular_qwen3_5.py | 4 ++-- src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py | 4 ++-- .../models/qwen3_omni_moe/modeling_qwen3_omni_moe.py | 2 +- src/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 4 ++-- src/transformers/models/qwen3_vl/modular_qwen3_vl.py | 4 ++-- src/transformers/models/qwen3_vl/processing_qwen3_vl.py | 4 ++-- src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 4 ++-- .../models/video_llama_3/modeling_video_llama_3.py | 2 +- .../models/video_llama_3/modular_video_llama_3.py | 2 +- .../models/video_llama_3/processing_video_llama_3.py | 2 +- .../{modeling_vision_utils.py => vision_utils.py} | 4 ++-- 32 files changed, 42 insertions(+), 40 deletions(-) rename src/transformers/{modeling_vision_utils.py => vision_utils.py} (98%) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index bbcc6b869d38..da4545a3ad66 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -37,11 +37,11 @@ from ...modeling_outputs import BaseModelOutputWithPooling, MoeCausalLMOutputWithPast, MoeModelOutputWithPast from 
...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_ernie4_5_vl_moe import Ernie4_5_VLMoeConfig, Ernie4_5_VLMoeTextConfig, Ernie4_5_VLMoeVisionConfig diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index b11d2426a652..d2c601f26068 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -42,7 +42,6 @@ from ...modeling_outputs import BaseModelOutputWithPooling, MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import ImagesKwargs, Unpack from ...utils import ( TensorType, @@ -53,6 +52,7 @@ ) from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ..ernie4_5_moe.configuration_ernie4_5_moe import Ernie4_5_MoeConfig from ..ernie4_5_moe.modeling_ernie4_5_moe import ( Ernie4_5_MoeAttention, diff --git a/src/transformers/models/glm46v/processing_glm46v.py b/src/transformers/models/glm46v/processing_glm46v.py index ab9713e1e4a8..4eafd3a4ab0b 100644 --- a/src/transformers/models/glm46v/processing_glm46v.py +++ b/src/transformers/models/glm46v/processing_glm46v.py @@ -23,12 +23,12 @@ from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging from ...utils.import_utils import requires_backends from ...video_utils import VideoInput +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens logger = logging.get_logger(__name__) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index d5047bfcc411..122fb7f687a4 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -38,11 +38,11 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from 
...utils.output_capturing import capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 57397d2d6c1a..4bca864b9a1c 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -33,7 +33,6 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_rope_utils import RopeParameters from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import ( @@ -47,6 +46,7 @@ from ...utils.import_utils import requires_backends from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, Glm4RotaryEmbedding, eager_attention_forward from ..qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VisionPatchEmbed, diff --git a/src/transformers/models/glm4v/processing_glm4v.py b/src/transformers/models/glm4v/processing_glm4v.py index 18e59c8db2e5..097fad40e01f 100644 --- a/src/transformers/models/glm4v/processing_glm4v.py +++ b/src/transformers/models/glm4v/processing_glm4v.py @@ -22,12 +22,12 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging from ...utils.import_utils import requires_backends from ...video_utils import VideoInput +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens logger = logging.get_logger(__name__) diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 19cbb8629b4e..4257a2f34814 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -38,11 +38,11 @@ from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, torch_compilable_check from ...utils.generic import can_return_tuple, is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index fcbc4106d7e9..44cdb0c01789 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -36,11 +36,11 
@@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_glm_image import GlmImageConfig, GlmImageTextConfig, GlmImageVisionConfig, GlmImageVQVAEConfig diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index 6ad5a61ea1c1..ad8b0e33174a 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -28,13 +28,13 @@ from ...image_utils import ImageInput from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import ImagesKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults from ...utils.import_utils import requires from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ..chameleon.modeling_chameleon import ChameleonVQVAE, ChameleonVQVAEModelOutput, ChameleonVQVAEVectorQuantizer from ..glm4v.configuration_glm4v import Glm4vTextConfig, Glm4vVisionConfig from ..glm4v.modeling_glm4v import ( diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index c55e79862285..0294f83ce5b9 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -38,11 +38,11 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_glm_ocr import GlmOcrConfig, GlmOcrTextConfig, GlmOcrVisionConfig diff --git a/src/transformers/models/glm_ocr/modular_glm_ocr.py b/src/transformers/models/glm_ocr/modular_glm_ocr.py index 5fc29c0c131a..e37f37d05a60 100644 --- a/src/transformers/models/glm_ocr/modular_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modular_glm_ocr.py @@ -20,8 +20,8 @@ from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...utils import 
auto_docstring +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ..glm4v.configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig from ..glm4v.modeling_glm4v import ( Glm4vForConditionalGeneration, diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index e6398739c7cd..7a796c9364b5 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -42,11 +42,11 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check, torch_int from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_paddleocr_vl import PaddleOCRTextConfig, PaddleOCRVisionConfig, PaddleOCRVLConfig diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index def3604ac6f6..268e34616983 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -37,7 +37,6 @@ from ...masking_utils import create_bidirectional_mask, create_causal_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...models.qwen2_vl.image_processing_pil_qwen2_vl import Qwen2VLImageProcessorPil from ...models.qwen2_vl.image_processing_qwen2_vl import Qwen2VLImageProcessor from ...processing_utils import ( @@ -58,6 +57,7 @@ ) from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ..ernie4_5.configuration_ernie4_5 import Ernie4_5Config from ..ernie4_5.modeling_ernie4_5 import ( Ernie4_5DecoderLayer, diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index de660453b1b3..7e72de11bb71 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -41,7 +41,6 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, @@ -55,6 +54,7 @@ from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.hub import cached_file from ...utils.output_capturing import capture_outputs +from ...vision_utils import 
get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from .configuration_qwen2_5_omni import ( Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniBigVGANConfig, diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index b89a0ce26ab9..f67cb7ec4ba4 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -33,7 +33,6 @@ from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import RopeParameters from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, @@ -47,6 +46,7 @@ from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults from ...utils.hub import cached_file from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from ..llama.modeling_llama import LlamaRotaryEmbedding, rotate_half from ..qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig from ..qwen2_5_vl.modeling_qwen2_5_vl import ( @@ -1365,6 +1365,7 @@ def forward( pool_indices = get_pool_indices(feature_lens) # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) + padded_feature = padded_feature.to(self.conv1.weight.dtype) padded_mask = ( (torch.arange(padded_feature.shape[2], device=padded_feature.device) < chunk_lengths.unsqueeze(1)) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 8649476d3ba8..b837eb5dcb60 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -42,11 +42,11 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index f9893da9bba5..62ddf74c3484 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -33,7 +33,6 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from ...processing_utils import MultiModalData, ProcessingKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, 
can_return_tuple, logging @@ -41,6 +40,7 @@ from ...utils.import_utils import requires_backends from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index from ..llama.modeling_llama import LlamaRMSNorm from ..qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig from ..qwen2_vl.modeling_qwen2_vl import ( diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index 7fb4542beedc..67fb99d8414b 100644 --- a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -24,12 +24,12 @@ # limitations under the License. from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring from ...utils.import_utils import requires_backends from ...video_utils import VideoInput +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index e94a78de6321..ae9ad8aa586e 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -38,7 +38,6 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import ( TransformersKwargs, @@ -53,6 +52,7 @@ merge_with_config_defaults, ) from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig diff --git a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py index ee30f820423c..30c5cc87ec57 100644 --- a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py @@ -22,12 +22,12 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging from ...utils.import_utils import requires_backends from ...video_utils import VideoInput +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens logger = logging.get_logger(__name__) @@ -104,6 +104,7 @@ def __call__( requires_backends(self, ["torch"]) spatial_merge_size = self.video_processor.merge_size videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) + videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) if not isinstance(text, list): 
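As a usage sketch of what the processor hunk above enables: the standalone helpers can be called outside the model to precompute the data-dependent vision tensors (for example before tracing or export). Only the `get_vision_cu_seqlens(grid_thw)` and `get_rotary_pos_ids(grid_thw, spatial_merge_size)` signatures and the import path come from this series; the toy `grid_thw` values and `spatial_merge_size=2` below are illustrative assumptions.

import torch

from transformers.vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens

# Two single-frame images with 4x4 and 8x6 patch grids (toy values).
grid_thw = torch.tensor([[1, 4, 4], [1, 8, 6]])
spatial_merge_size = 2  # assumed vision config value; must divide grid height and width

# Cumulative per-frame sequence boundaries, here tensor([0, 16, 64]).
cu_seqlens = get_vision_cu_seqlens(grid_thw)
# (total_tokens, 2) tensor of (row, col) positions, here shape (64, 2).
rotary_pos_ids = get_rotary_pos_ids(grid_thw, spatial_merge_size)

These tensors can then be passed to the vision encoder forward (or attached to processor outputs, as done for `video_cu_seqlens` / `video_rotary_pos_ids` above) so the encoder skips the untraceable recomputation.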
diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 8b6f5305bd38..9df09add44a5 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -43,13 +43,13 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from .configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 99d17e3cc285..577e93597f29 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -26,12 +26,12 @@ from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel -from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ..qwen3.modeling_qwen3 import Qwen3ForCausalLM from ..qwen3_next.configuration_qwen3_next import Qwen3NextConfig from ..qwen3_next.modeling_qwen3_next import ( diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index af4dd100a933..90ab599d772c 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -44,13 +44,13 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import OutputRecorder, 
capture_outputs +from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from .configuration_qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig, Qwen3_5MoeVisionConfig diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index cfeac1349825..481b5cdb92b0 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -52,7 +52,6 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import ( @@ -62,6 +61,7 @@ merge_with_config_defaults, ) from ...utils.output_capturing import OutputRecorder, capture_outputs +from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_qwen3_omni_moe import ( Qwen3OmniMoeAudioEncoderConfig, Qwen3OmniMoeCode2WavConfig, diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 1b63d9aa7b31..a58ddaadfba9 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -37,12 +37,12 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 249904835b00..6f58dd51f84e 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -33,8 +33,6 @@ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_rope_utils import RopeParameters, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ...processing_utils import ProcessingKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, can_return_tuple, logging @@ -42,6 +40,8 @@ from ...utils.import_utils import requires_backends 
from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput +from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ..llama.modeling_llama import LlamaRotaryEmbedding from ..qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLCausalLMOutputWithPast, diff --git a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py index 594994b7adb6..ec6ed045f31c 100644 --- a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py @@ -22,13 +22,13 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids -from ...modeling_vision_utils import get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging from ...utils.import_utils import requires_backends from ...video_utils import VideoInput +from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids +from ...vision_utils import get_vision_cu_seqlens logger = logging.get_logger(__name__) diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index c7b81854d379..0ef879dc96ba 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -43,12 +43,12 @@ from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...modeling_vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs +from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens +from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index eb7c2f0e0fb9..8028fd2e9f73 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -32,11 +32,11 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import 
is_flash_attention_requested, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ..auto.modeling_auto import AutoModel from .configuration_video_llama_3 import VideoLlama3Config, VideoLlama3VisionConfig diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index e47c22c0bc85..1af49e9c4551 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -37,7 +37,6 @@ ) from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import ImagesKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import ( @@ -54,6 +53,7 @@ group_videos_by_shape, reorder_videos, ) +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ..auto import CONFIG_MAPPING, AutoConfig from ..auto.modeling_auto import AutoModel from ..qwen2_vl.image_processing_pil_qwen2_vl import Qwen2VLImageProcessorPil diff --git a/src/transformers/models/video_llama_3/processing_video_llama_3.py b/src/transformers/models/video_llama_3/processing_video_llama_3.py index f77f7a46d2ed..4e4835e21fb3 100644 --- a/src/transformers/models/video_llama_3/processing_video_llama_3.py +++ b/src/transformers/models/video_llama_3/processing_video_llama_3.py @@ -19,12 +19,12 @@ # limitations under the License. from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...modeling_vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging from ...utils.import_utils import requires_backends from ...video_utils import VideoInput +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens logger = logging.get_logger(__name__) diff --git a/src/transformers/modeling_vision_utils.py b/src/transformers/vision_utils.py similarity index 98% rename from src/transformers/modeling_vision_utils.py rename to src/transformers/vision_utils.py index cff2f599448f..26e94ffe3fed 100644 --- a/src/transformers/modeling_vision_utils.py +++ b/src/transformers/vision_utils.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Pure vision utility functions for pre-computing very dynamic and -data-dependent tensors that can break model capturing and tracing. +"""Vision utility functions for pre-computing very dynamic and +data-dependent tensors that can break model graph capturing. All functions are standalone (no model weights) and compute tensors from ``grid_thw`` + config scalars. 
They are used by vision encoders and can be From d401a3328a170db6f91345d88fa7b998ef811719 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 15 Apr 2026 16:19:05 +0200 Subject: [PATCH 36/56] only one get_rotary_pos_ids is needed --- .../models/qwen3_5/modeling_qwen3_5.py | 3 +- .../models/qwen3_5/modular_qwen3_5.py | 3 +- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 3 +- .../models/qwen3_vl/modeling_qwen3_vl.py | 3 +- .../models/qwen3_vl/modular_qwen3_vl.py | 3 +- .../models/qwen3_vl/processing_qwen3_vl.py | 3 +- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 3 +- src/transformers/vision_utils.py | 49 ------------------- 8 files changed, 7 insertions(+), 63 deletions(-) diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 9df09add44a5..c7529473c881 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -48,8 +48,7 @@ from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids +from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 577e93597f29..aa8aef5dc0b7 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -30,8 +30,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids +from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens from ..qwen3.modeling_qwen3 import Qwen3ForCausalLM from ..qwen3_next.configuration_qwen3_next import Qwen3NextConfig from ..qwen3_next.modeling_qwen3_next import ( diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 90ab599d772c..63816fa8341f 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -49,8 +49,7 @@ from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import OutputRecorder, capture_outputs -from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids +from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig, Qwen3_5MoeVisionConfig diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 
a58ddaadfba9..095792ae03af 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -41,8 +41,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids +from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 6f58dd51f84e..bea38ae9afe9 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -40,8 +40,7 @@ from ...utils.import_utils import requires_backends from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput -from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids +from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens from ..llama.modeling_llama import LlamaRotaryEmbedding from ..qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLCausalLMOutputWithPast, diff --git a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py index ec6ed045f31c..eea52dad8a6a 100644 --- a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py @@ -27,8 +27,7 @@ from ...utils import auto_docstring, logging from ...utils.import_utils import requires_backends from ...video_utils import VideoInput -from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids -from ...vision_utils import get_vision_cu_seqlens +from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens logger = logging.get_logger(__name__) diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 0ef879dc96ba..8854c7a117ac 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -47,8 +47,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs -from ...vision_utils import get_pos_embed_indices, get_vision_cu_seqlens -from ...vision_utils import get_rotary_pos_ids_interleaved as get_rotary_pos_ids +from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig diff --git a/src/transformers/vision_utils.py b/src/transformers/vision_utils.py index 26e94ffe3fed..b80a48bdfbde 100644 --- a/src/transformers/vision_utils.py +++ b/src/transformers/vision_utils.py @@ -73,55 +73,6 @@ def get_rotary_pos_ids(grid_thw: torch.Tensor, spatial_merge_size: int | torch.T return 
torch.cat(pos_ids, dim=0) -def get_rotary_pos_ids_interleaved(grid_thw: torch.Tensor, spatial_merge_size: int) -> torch.Tensor: - """Compute (row, col) position IDs for Qwen3-VL style vision rotary embeddings. - - Uses block-interleaved positions with intra-block offsets (different from the - Qwen2-VL variant which permutes whole rows/columns). - - Args: - grid_thw: ``(num_images_or_videos, 3)`` - spatial_merge_size: merge block size from vision config. - - Returns: - ``pos_ids``: ``(total_tokens, 2)`` long — (row, col) position per token. - """ - m = spatial_merge_size - device = grid_thw.device - total_tokens = int((grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).sum().item()) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw.tolist(): - num_frames, height, width = int(num_frames), int(height), int(width) - merged_h, merged_w = height // m, width // m - - block_rows = torch.arange(merged_h, device=device) - block_cols = torch.arange(merged_w, device=device) - intra_row = torch.arange(m, device=device) - intra_col = torch.arange(m, device=device) - - row_idx = ( - (block_rows[:, None, None, None] * m + intra_row[None, None, :, None]) - .expand(merged_h, merged_w, m, m) - .reshape(-1) - ) - col_idx = ( - (block_cols[None, :, None, None] * m + intra_col[None, None, None, :]) - .expand(merged_h, merged_w, m, m) - .reshape(-1) - ) - - coords = torch.stack((row_idx, col_idx), dim=-1) - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - return pos_ids - def get_window_index( grid_thw: torch.Tensor, From fc49a3f9f41254f0e66c3af1cb3d5c6e9333ed0f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 15 Apr 2026 16:19:32 +0200 Subject: [PATCH 37/56] style --- src/transformers/vision_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/vision_utils.py b/src/transformers/vision_utils.py index b80a48bdfbde..75c154594b29 100644 --- a/src/transformers/vision_utils.py +++ b/src/transformers/vision_utils.py @@ -73,7 +73,6 @@ def get_rotary_pos_ids(grid_thw: torch.Tensor, spatial_merge_size: int | torch.T return torch.cat(pos_ids, dim=0) - def get_window_index( grid_thw: torch.Tensor, spatial_merge_size: int, From 4711af6acb5c918dc3b180894a1961fdd0dcb2b8 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 15 Apr 2026 16:22:10 +0200 Subject: [PATCH 38/56] style --- src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index f67cb7ec4ba4..23f8c6cf9f31 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1365,7 +1365,6 @@ def forward( pool_indices = get_pool_indices(feature_lens) # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) - padded_feature = padded_feature.to(self.conv1.weight.dtype) padded_mask = ( (torch.arange(padded_feature.shape[2], device=padded_feature.device) < chunk_lengths.unsqueeze(1)) From e85551e03452c2b1613e0f79a44409926c919b9b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Apr 2026 15:48:48 +0200 Subject: [PATCH 39/56] deprecate only --- .../modeling_ernie4_5_vl_moe.py | 11 +++ .../models/glm4v/modeling_glm4v.py | 11 +++ 
.../models/glm4v/modular_glm4v.py | 11 +++ .../models/glm4v_moe/modeling_glm4v_moe.py | 11 +++ .../models/glm_image/modeling_glm_image.py | 9 ++ .../models/glm_image/modular_glm_image.py | 9 ++ .../models/glm_ocr/modeling_glm_ocr.py | 11 +++ .../qwen2_5_omni/modeling_qwen2_5_omni.py | 69 ++++++++++++++ .../qwen2_5_omni/modular_qwen2_5_omni.py | 44 +++++++++ .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 26 ++++++ .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 26 ++++++ .../models/qwen2_vl/modeling_qwen2_vl.py | 11 +++ .../models/qwen3_5/modeling_qwen3_5.py | 22 +++++ .../qwen3_5_moe/modeling_qwen3_5_moe.py | 22 +++++ .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 89 +++++++++++++++++++ .../models/qwen3_vl/modeling_qwen3_vl.py | 22 +++++ .../models/qwen3_vl/modular_qwen3_vl.py | 22 +++++ .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 22 +++++ 18 files changed, 448 insertions(+) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index 23ebeafbe4a8..a578f08b2639 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -19,6 +19,7 @@ # limitations under the License. import itertools +import warnings from collections.abc import Callable from typing import Any, Optional @@ -932,6 +933,16 @@ def forward( hidden_states = self.ln(hidden_states) return BaseModelOutputWithPooling(last_hidden_state=hidden_states) + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb + class Ernie4_5_VLMoeVisionMLP(nn.Module): def __init__(self, config, in_dim, out_dim): diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 5cb9401439a8..8a902370e39b 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -794,6 +795,16 @@ def forward( pooler_output=merged_hidden_states, ) + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb, pos_ids + @auto_docstring class Glm4vTextModel(Glm4vPreTrainedModel): diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 447cffb0d227..c0e31a087c58 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import warnings from collections.abc import Callable import numpy as np @@ -679,6 +680,16 @@ def forward( pooler_output=merged_hidden_states, ) + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb, pos_ids + class Glm4vTextModel(Qwen2_5_VLTextModel): _can_record_outputs = { diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index c10fb9001cf5..63aade7b890a 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -858,6 +859,16 @@ def forward( pooler_output=merged_hidden_states, ) + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb, pos_ids + class Glm4vMoeTextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 7179c32597ec..f3510aa3247e 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -639,6 +640,14 @@ def forward( return BaseModelOutputWithPooling(last_hidden_state=hidden_states) + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + return get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + @use_kernel_forward_from_hub("RMSNorm") class GlmImageRMSNorm(nn.Module): diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index ad8b0e33174a..d5d7134f508b 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -13,6 +13,7 @@ # limitations under the License. 
import math +import warnings from collections.abc import Callable from typing import Any @@ -479,6 +480,14 @@ def forward( return BaseModelOutputWithPooling(last_hidden_state=hidden_states) + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + return get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + class GlmImageTextModel(Glm4vTextModel): pass diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index a0b32e31b3f5..e1702e136461 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -19,6 +19,7 @@ # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -634,6 +635,16 @@ def forward( pooler_output=merged_hidden_states, ) + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb, pos_ids + class GlmOcrTextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 7e72de11bb71..89f648c69265 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -20,6 +20,7 @@ # limitations under the License. import math +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -940,6 +941,49 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): output_lengths = (input_lengths - 2) // 2 + 1 return input_lengths, output_lengths + def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): + """ + Pads a sequence of tensors to their maximum length on indicated `padding_side`. + Then prepares a mask so that pad tokens are not attended to. + """ + warnings.warn( + f"`{self.__class__.__name__}.padded_and_mask_function` is deprecated and will be removed in a future version. 
Use `chunk_and_pad_features` and `get_audio_cu_seqlens` helpers instead.", + FutureWarning, + stacklevel=2, + ) + max_len = tensor_len.max() + dim = tensor_list[0].shape[0] + padded_tensor = torch.full( + size=(len(tensor_list), dim, max_len), + fill_value=padding_value, + dtype=self.dtype, + device=tensor_list[0].device, + ) + + batch_mask = torch.zeros( + (len(tensor_len), max_len), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(tensor_len): + batch_mask[i, :length] = 1 + padded_tensor[i, :, :length] = tensor_list[i] + + feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 + max_len_after_cnn = feature_lens_after_cnn.max() + batch_mask_after_cnn = torch.zeros( + (len(tensor_len), max_len_after_cnn), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(feature_lens_after_cnn): + batch_mask_after_cnn[i, :length] = 1 + return ( + padded_tensor, + batch_mask.unsqueeze(1), + batch_mask_after_cnn.bool(), + ) + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -1280,6 +1324,31 @@ def forward( pooler_output=merged_hidden_states, ) + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_window_index` from `transformers.vision_utils` instead.", + FutureWarning, + stacklevel=2, + ) + window_index, cu_window_seqlens = get_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + ) + return window_index, cu_window_seqlens.tolist() + class Qwen2_5OmniRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 23f8c6cf9f31..90abc7a27cf7 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -15,6 +15,7 @@ """PyTorch Qwen2.5Omni model (Audio, Image, Video).""" import math +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -1422,6 +1423,49 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): output_lengths = (input_lengths - 2) // 2 + 1 return input_lengths, output_lengths + def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): + """ + Pads a sequence of tensors to their maximum length on indicated `padding_side`. + Then prepares a mask so that pad tokens are not attended to. + """ + warnings.warn( + f"`{self.__class__.__name__}.padded_and_mask_function` is deprecated and will be removed in a future version. 
Use `chunk_and_pad_features` and `get_audio_cu_seqlens` helpers instead.", + FutureWarning, + stacklevel=2, + ) + max_len = tensor_len.max() + dim = tensor_list[0].shape[0] + padded_tensor = torch.full( + size=(len(tensor_list), dim, max_len), + fill_value=padding_value, + dtype=self.dtype, + device=tensor_list[0].device, + ) + + batch_mask = torch.zeros( + (len(tensor_len), max_len), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(tensor_len): + batch_mask[i, :length] = 1 + padded_tensor[i, :, :length] = tensor_list[i] + + feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 + max_len_after_cnn = feature_lens_after_cnn.max() + batch_mask_after_cnn = torch.zeros( + (len(tensor_len), max_len_after_cnn), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(feature_lens_after_cnn): + batch_mask_after_cnn[i, :length] = 1 + return ( + padded_tensor, + batch_mask.unsqueeze(1), + batch_mask_after_cnn.bool(), + ) + def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: orig_dtype = tensor.dtype diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 351523fe2752..9e54f36e06fe 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -24,6 +24,7 @@ # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -455,6 +456,31 @@ def forward( pooler_output=merged_hidden_states, ) + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_window_index` from `transformers.vision_utils` instead.", + FutureWarning, + stacklevel=2, + ) + window_index, cu_window_seqlens = get_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + ) + return window_index, cu_window_seqlens.tolist() + @dataclass @auto_docstring( diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 62ddf74c3484..4632bff776fd 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -19,6 +19,7 @@ """PyTorch Qwen2.5-VL model.""" import itertools +import warnings import torch import torch.nn as nn @@ -297,6 +298,31 @@ def forward( pooler_output=merged_hidden_states, ) + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_window_index` from `transformers.vision_utils` instead.", + FutureWarning, + stacklevel=2, + ) + window_index, cu_window_seqlens = get_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + ) + return window_index, cu_window_seqlens.tolist() + class Qwen2_5_VLModelOutputWithPast(Qwen2VLModelOutputWithPast): pass diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 4827808e17ab..32897705ad41 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -19,6 +19,7 @@ """PyTorch Qwen2-VL model.""" import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -767,6 +768,16 @@ def forward( pooler_output=merged_hidden_states, ) + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb + @auto_docstring class Qwen2VLTextModel(Qwen2VLPreTrainedModel): diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index bb7f82848584..51922261511a 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -19,6 +19,7 @@ # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -1094,6 +1095,27 @@ def forward( pooler_output=merged_hidden_states, ) + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + @dataclass @auto_docstring( diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 1bda96d436c7..dd35f8887fb1 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -19,6 +19,7 @@ # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -1187,6 +1188,27 @@ def forward( pooler_output=merged_hidden_states, ) + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + @dataclass @auto_docstring( diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index c58561deb4e5..f0f070fca761 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -20,6 +20,7 @@ # limitations under the License. import math +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Optional @@ -848,6 +849,73 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): output_lengths = (input_lengths - 2) // 2 + 1 return input_lengths, output_lengths + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention mask only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + warnings.warn( + f"`{self.__class__.__name__}._prepare_attention_mask` is deprecated and will be removed in a future version.", + FutureWarning, + stacklevel=2, + ) + if is_flash_attention_requested(self.config): + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + + def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): + """ + Pads a sequence of tensors to their maximum length on indicated `padding_side`. + Then prepares a mask so that pad tokens are not attended to. + """ + warnings.warn( + f"`{self.__class__.__name__}.padded_and_mask_function` is deprecated and will be removed in a future version. Use `chunk_and_pad_features` and `get_audio_cu_seqlens` helpers instead.", + FutureWarning, + stacklevel=2, + ) + max_len = tensor_len.max() + dim = tensor_list[0].shape[0] + padded_tensor = torch.full( + size=(len(tensor_list), dim, max_len), + fill_value=padding_value, + dtype=self.dtype, + device=tensor_list[0].device, + ) + + batch_mask = torch.zeros( + (len(tensor_len), max_len), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(tensor_len): + batch_mask[i, :length] = 1 + padded_tensor[i, :, :length] = tensor_list[i] + + feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 + max_len_after_cnn = feature_lens_after_cnn.max() + batch_mask_after_cnn = torch.zeros( + (len(tensor_len), max_len_after_cnn), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(feature_lens_after_cnn): + batch_mask_after_cnn[i, :length] = 1 + return ( + padded_tensor, + batch_mask.unsqueeze(1), + batch_mask_after_cnn.bool(), + ) + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -1198,6 +1266,27 @@ def forward( deepstack_features=deepstack_feature_lists, ) + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + @property def deepstack_merger_list(self): return self.merger_list diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 5d3b9966d36a..953897234a25 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -19,6 +19,7 @@ # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -732,6 +733,27 @@ def forward( deepstack_features=deepstack_feature_lists, ) + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + @auto_docstring( custom_intro=( diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index bea38ae9afe9..f9c627b7461a 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -13,6 +13,7 @@ # limitations under the License. """PyTorch Qwen3-VL model.""" +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -521,6 +522,27 @@ def forward( deepstack_features=deepstack_feature_lists, ) + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + @auto_docstring( custom_intro=( diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index af48f91050dc..a1f5125dec97 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -19,6 +19,7 @@ # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -717,6 +718,27 @@ def forward( deepstack_features=deepstack_feature_lists, ) + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(pos_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + embed_indices, bilinear_weights = get_pos_embed_indices( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + class Qwen3VLMoeTextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` From 531f13cc7b6d176a2062e129705cf1238ce0e0b6 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Apr 2026 15:51:38 +0200 Subject: [PATCH 40/56] fix --- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index f0f070fca761..c787f7d76bc5 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -849,30 +849,6 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): output_lengths = (input_lengths - 2) // 2 + 1 return input_lengths, output_lengths - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention mask only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. 
Though it will not be a 100% match for FA2's `varlen` path - warnings.warn( - f"`{self.__class__.__name__}._prepare_attention_mask` is deprecated and will be removed in a future version.", - FutureWarning, - stacklevel=2, - ) - if is_flash_attention_requested(self.config): - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): """ Pads a sequence of tensors to their maximum length on indicated `padding_side`. From 4c3d84d63bc5c62d48d55d12fec2f79387842ebb Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Apr 2026 17:18:23 +0200 Subject: [PATCH 41/56] simplify and revert processor changes --- .../models/glm46v/processing_glm46v.py | 13 ------------- src/transformers/models/glm4v/modular_glm4v.py | 12 ------------ src/transformers/models/glm4v/processing_glm4v.py | 13 ------------- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 12 ------------ .../models/qwen2_5_vl/processing_qwen2_5_vl.py | 13 ------------- .../models/qwen2_vl/processing_qwen2_vl.py | 14 -------------- .../models/qwen3_vl/modular_qwen3_vl.py | 12 ------------ .../models/qwen3_vl/processing_qwen3_vl.py | 13 ------------- .../models/video_llama_3/modular_video_llama_3.py | 6 ------ .../video_llama_3/processing_video_llama_3.py | 7 ------- 10 files changed, 115 deletions(-) diff --git a/src/transformers/models/glm46v/processing_glm46v.py b/src/transformers/models/glm46v/processing_glm46v.py index 4eafd3a4ab0b..9dcf7c4856e6 100644 --- a/src/transformers/models/glm46v/processing_glm46v.py +++ b/src/transformers/models/glm46v/processing_glm46v.py @@ -26,9 +26,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging -from ...utils.import_utils import requires_backends from ...video_utils import VideoInput -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens logger = logging.get_logger(__name__) @@ -85,7 +83,6 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" - return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Glm46VProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -94,11 +91,6 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.image_processor.merge_size - image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) - image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) else: image_inputs = {} image_grid_thw = None @@ -111,11 +103,6 @@ def __call__( else: video_metadata = videos_inputs["video_metadata"] video_grid_thw = videos_inputs["video_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.video_processor.merge_size - videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) - videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) else: videos_inputs = {} video_grid_thw = None diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index c0e31a087c58..a3b754463b1d 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -43,7 +43,6 @@ torch_compilable_check, ) from ...utils.generic import maybe_autocast, merge_with_config_defaults -from ...utils.import_utils import requires_backends from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens @@ -1210,7 +1209,6 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" - return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Glm4vProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -1219,11 +1217,6 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.image_processor.merge_size - image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) - image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) else: image_inputs = {} image_grid_thw = None @@ -1236,11 +1229,6 @@ def __call__( else: video_metadata = videos_inputs["video_metadata"] video_grid_thw = videos_inputs["video_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.video_processor.merge_size - videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) - videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) else: videos_inputs = {} video_grid_thw = None diff --git a/src/transformers/models/glm4v/processing_glm4v.py b/src/transformers/models/glm4v/processing_glm4v.py index 097fad40e01f..2d3e93aec9ed 100644 --- a/src/transformers/models/glm4v/processing_glm4v.py +++ b/src/transformers/models/glm4v/processing_glm4v.py @@ -25,9 +25,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging -from ...utils.import_utils import requires_backends from ...video_utils import VideoInput -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens logger = logging.get_logger(__name__) @@ -84,7 +82,6 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" - return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Glm4vProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -93,11 +90,6 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.image_processor.merge_size - image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) - image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) else: image_inputs = {} image_grid_thw = None @@ -110,11 +102,6 @@ def __call__( else: video_metadata = videos_inputs["video_metadata"] video_grid_thw = videos_inputs["video_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.video_processor.merge_size - videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) - videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) else: videos_inputs = {} video_grid_thw = None diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 4632bff776fd..d3c3a3987d71 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -38,7 +38,6 @@ from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults -from ...utils.import_utils import requires_backends from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index @@ -797,7 +796,6 @@ def __call__( - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. 
""" - return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Qwen2_5_VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -808,20 +806,10 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.image_processor.merge_size - image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) - image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.video_processor.merge_size - videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) - videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) # Get video metadata if not kwargs.get("return_metadata"): diff --git a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py index 67fb99d8414b..8873eb82557a 100644 --- a/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py @@ -27,9 +27,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring -from ...utils.import_utils import requires_backends from ...video_utils import VideoInput -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): @@ -81,7 +79,6 @@ def __call__( - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. 
""" - return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Qwen2_5_VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -92,20 +89,10 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.image_processor.merge_size - image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) - image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.video_processor.merge_size - videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) - videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) # Get video metadata if not kwargs.get("return_metadata"): diff --git a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py index 30c5cc87ec57..9c38451e60e8 100644 --- a/src/transformers/models/qwen2_vl/processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/processing_qwen2_vl.py @@ -25,9 +25,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging -from ...utils.import_utils import requires_backends from ...video_utils import VideoInput -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens logger = logging.get_logger(__name__) @@ -80,7 +78,6 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" - return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Qwen2VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -91,21 +88,10 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.image_processor.merge_size - image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) - image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.video_processor.merge_size - videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) - - videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) if not isinstance(text, list): text = [text] diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index f9c627b7461a..ab885673a9eb 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -38,7 +38,6 @@ from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, can_return_tuple, logging from ...utils.generic import maybe_autocast, merge_with_config_defaults -from ...utils.import_utils import requires_backends from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens @@ -1180,7 +1179,6 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" - return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Qwen3VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -1189,11 +1187,6 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.image_processor.merge_size - image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) - image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) else: image_inputs = {} image_grid_thw = None @@ -1201,11 +1194,6 @@ def __call__( if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.video_processor.merge_size - videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) - videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) # If user has not requested video metadata, pop it if not kwargs.get("return_metadata"): video_metadata = videos_inputs.pop("video_metadata") diff --git a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py index eea52dad8a6a..1ca435749ad2 100644 --- a/src/transformers/models/qwen3_vl/processing_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/processing_qwen3_vl.py @@ -25,9 +25,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging -from ...utils.import_utils import requires_backends from ...video_utils import VideoInput -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens logger = logging.get_logger(__name__) @@ -98,7 +96,6 @@ def __call__( - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
""" - return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( Qwen3VLProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -107,11 +104,6 @@ def __call__( if images is not None: image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.image_processor.merge_size - image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) - image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, spatial_merge_size) else: image_inputs = {} image_grid_thw = None @@ -119,11 +111,6 @@ def __call__( if videos is not None: videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) video_grid_thw = videos_inputs["video_grid_thw"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - spatial_merge_size = self.video_processor.merge_size - videos_inputs["video_cu_seqlens"] = get_vision_cu_seqlens(video_grid_thw) - videos_inputs["video_rotary_pos_ids"] = get_rotary_pos_ids(video_grid_thw, spatial_merge_size) # If user has not requested video metadata, pop it if not kwargs.get("return_metadata"): video_metadata = videos_inputs.pop("video_metadata") diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 1af49e9c4551..ee3ec500cbdb 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -46,7 +46,6 @@ logging, ) from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults -from ...utils.import_utils import requires_backends from ...utils.output_capturing import capture_outputs from ...video_utils import ( VideoInput, @@ -1066,7 +1065,6 @@ def __call__( videos: VideoInput = None, **kwargs: Unpack[VideoLlama3ProcessorKwargs], ) -> BatchFeature: - return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( VideoLlama3ProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -1078,10 +1076,6 @@ def __call__( image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] image_merge_sizes = image_inputs["image_merge_sizes"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) - image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, image_merge_sizes) else: image_grid_thw = image_merge_sizes = [] diff --git a/src/transformers/models/video_llama_3/processing_video_llama_3.py b/src/transformers/models/video_llama_3/processing_video_llama_3.py index 4e4835e21fb3..7916d7e41d8e 100644 --- a/src/transformers/models/video_llama_3/processing_video_llama_3.py +++ b/src/transformers/models/video_llama_3/processing_video_llama_3.py @@ -22,9 +22,7 @@ from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, logging -from ...utils.import_utils import requires_backends from ...video_utils import VideoInput -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens logger = logging.get_logger(__name__) @@ -78,7 +76,6 @@ def __call__( - **image_grid_thw** -- List of image 3D 
grid in LLM. Returned when `images` is not `None`. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. """ - return_extra_tensors = kwargs.pop("return_extra_tensors", False) output_kwargs = self._merge_kwargs( VideoLlama3ProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, @@ -90,10 +87,6 @@ def __call__( image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) image_grid_thw = image_inputs["image_grid_thw"] image_merge_sizes = image_inputs["image_merge_sizes"] - if return_extra_tensors: - requires_backends(self, ["torch"]) - image_inputs["image_cu_seqlens"] = get_vision_cu_seqlens(image_grid_thw) - image_inputs["image_rotary_pos_ids"] = get_rotary_pos_ids(image_grid_thw, image_merge_sizes) else: image_grid_thw = image_merge_sizes = [] From 9ea6203d29f5d4f107d2ddfd1f3f525a8d8a6a42 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 21 Apr 2026 11:15:53 +0200 Subject: [PATCH 42/56] renames --- .../modeling_ernie4_5_vl_moe.py | 52 +++++------- .../modular_ernie4_5_vl_moe.py | 18 ++--- .../models/glm46v/modeling_glm46v.py | 24 ++---- .../models/glm4v/modeling_glm4v.py | 54 ++++++------- .../models/glm4v/modular_glm4v.py | 30 +++---- .../models/glm4v_moe/modeling_glm4v_moe.py | 54 ++++++------- .../models/glm_image/modeling_glm_image.py | 28 +++---- .../models/glm_image/modular_glm_image.py | 28 +++---- .../models/glm_ocr/modeling_glm_ocr.py | 50 +++++------- .../models/glm_ocr/modular_glm_ocr.py | 14 ++-- .../paddleocr_vl/modeling_paddleocr_vl.py | 34 ++++---- .../paddleocr_vl/modular_paddleocr_vl.py | 24 +++--- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 44 +++++----- .../qwen2_5_omni/modular_qwen2_5_omni.py | 28 +++---- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 72 +++++++---------- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 40 +++++----- .../models/qwen2_vl/modeling_qwen2_vl.py | 80 +++++++------------ .../models/qwen3_5/modeling_qwen3_5.py | 64 +++++++-------- .../models/qwen3_5/modular_qwen3_5.py | 32 ++++---- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 64 +++++++-------- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 52 ++++++------ .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 8 +- .../models/qwen3_vl/modeling_qwen3_vl.py | 72 ++++++++--------- .../models/qwen3_vl/modular_qwen3_vl.py | 56 ++++++------- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 72 ++++++++--------- .../video_llama_3/modeling_video_llama_3.py | 34 ++++---- .../video_llama_3/modular_video_llama_3.py | 30 +++---- src/transformers/utils/auto_docstring.py | 4 +- src/transformers/vision_utils.py | 20 ++--- 29 files changed, 542 insertions(+), 640 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index a578f08b2639..a2d3e4350bda 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -42,7 +42,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_ernie4_5_vl_moe import Ernie4_5_VLMoeConfig, Ernie4_5_VLMoeTextConfig, 
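Note: a minimal sketch of how the renamed helpers are meant to be used together; the grid values are illustrative and the top-level import path is assumed from the `from ...vision_utils import ...` lines in this patch.

    import torch
    from transformers.vision_utils import get_vision_cu_seqlens, get_vision_position_ids

    # One still image split into a 16x16 patch grid (t=1, h=16, w=16), spatial_merge_size=2.
    grid_thw = torch.tensor([[1, 16, 16]])
    position_ids = get_vision_position_ids(grid_thw, 2)  # (256, 2) long, (row, col) per token
    cu_seqlens = get_vision_cu_seqlens(grid_thw)         # tensor([0, 256], dtype=torch.int32)

    # The precomputed tensors can then be forwarded to the vision tower, e.g.
    # model.get_image_features(pixel_values, image_grid_thw=grid_thw,
    #                          image_cu_seqlens=cu_seqlens, image_position_ids=position_ids)
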
Ernie4_5_VLMoeVisionConfig @@ -857,8 +857,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) @auto_docstring @@ -900,7 +900,7 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -908,15 +908,15 @@ def forward( The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). """ hidden_states = self.patch_embed(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -935,12 +935,12 @@ def forward( def rot_pos_emb(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb @@ -1257,7 +1257,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1265,16 +1265,12 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from the processor). - video_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs (from the processor). 
""" video_outputs = self.vision_tower( pixel_values_videos, video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, return_dict=True, **kwargs, ) @@ -1295,7 +1291,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1303,16 +1299,12 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from the processor). - image_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs (from the processor). """ image_outputs = self.vision_tower( pixel_values, image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) @@ -1604,7 +1596,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1614,14 +1606,12 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, **kwargs, ) @@ -1631,7 +1621,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1641,14 +1631,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
""" return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index d2c601f26068..4a73baec12da 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -52,7 +52,7 @@ ) from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..ernie4_5_moe.configuration_ernie4_5_moe import Ernie4_5_MoeConfig from ..ernie4_5_moe.modeling_ernie4_5_moe import ( Ernie4_5_MoeAttention, @@ -703,15 +703,15 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: hidden_states = self.patch_embed(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -971,14 +971,14 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: video_outputs = self.vision_tower( pixel_values_videos, video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, return_dict=True, **kwargs, ) @@ -999,14 +999,14 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: image_outputs = self.vision_tower( pixel_values, image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index 5e243a2536c8..4505927a7254 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -261,7 +261,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -282,7 +282,7 @@ def get_video_features( pixel_values_videos, grid_thw=flattened_video_grid_thw, 
cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, return_dict=True, **kwargs, ) @@ -299,7 +299,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -307,17 +307,13 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from the processor). - image_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs (from the processor). """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -539,7 +535,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -549,14 +545,12 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, **kwargs, ) @@ -566,7 +560,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -576,14 +570,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
""" return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 8a902370e39b..77a181c4e54d 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -43,7 +43,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig @@ -112,8 +112,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Glm4vVisionPatchMerger(nn.Module): @@ -734,7 +734,7 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -744,8 +744,8 @@ def forward( The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + position_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. @@ -753,13 +753,13 @@ def forward( hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -768,8 +768,8 @@ def forward( hidden_states, seqlens, grid_thw, - rotary_pos_ids[:, 0].to(hidden_states.device), - rotary_pos_ids[:, 1].to(hidden_states.device), + position_ids[:, 0].to(hidden_states.device), + position_ids[:, 1].to(hidden_states.device), ) for blk in self.blocks: @@ -797,13 +797,13 @@ def forward( def rot_pos_emb(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) - return rotary_pos_emb, pos_ids + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids @auto_docstring @@ -1091,7 +1091,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1112,7 +1112,7 @@ def get_video_features( pixel_values_videos, grid_thw=flattened_video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, return_dict=True, **kwargs, ) @@ -1129,7 +1129,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1137,17 +1137,13 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from the processor). - image_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs (from the processor). """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1369,7 +1365,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1379,14 +1375,12 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). 
""" return self.model.get_video_features( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, **kwargs, ) @@ -1396,7 +1390,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1406,14 +1400,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index a3b754463b1d..121f02e6888c 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -45,7 +45,7 @@ from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, Glm4RotaryEmbedding, eager_attention_forward from ..qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VisionPatchEmbed, @@ -618,7 +618,7 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -628,8 +628,8 @@ def forward( The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + position_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. 
@@ -637,13 +637,13 @@ def forward( hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -652,8 +652,8 @@ def forward( hidden_states, seqlens, grid_thw, - rotary_pos_ids[:, 0].to(hidden_states.device), - rotary_pos_ids[:, 1].to(hidden_states.device), + position_ids[:, 0].to(hidden_states.device), + position_ids[:, 1].to(hidden_states.device), ) for blk in self.blocks: @@ -681,13 +681,13 @@ def forward( def rot_pos_emb(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) - return rotary_pos_emb, pos_ids + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids class Glm4vTextModel(Qwen2_5_VLTextModel): @@ -800,7 +800,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -821,7 +821,7 @@ def get_video_features( pixel_values_videos, grid_thw=flattened_video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 63aade7b890a..f7ebbbc0a562 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -43,7 +43,7 @@ from ...utils import TransformersKwargs, auto_docstring, torch_compilable_check from ...utils.generic import can_return_tuple, is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig @@ -471,8 +471,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) 
-> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) @use_kernel_forward_from_hub("RMSNorm") @@ -798,7 +798,7 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -808,8 +808,8 @@ def forward( The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + position_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. @@ -817,13 +817,13 @@ def forward( hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -832,8 +832,8 @@ def forward( hidden_states, seqlens, grid_thw, - rotary_pos_ids[:, 0].to(hidden_states.device), - rotary_pos_ids[:, 1].to(hidden_states.device), + position_ids[:, 0].to(hidden_states.device), + position_ids[:, 1].to(hidden_states.device), ) for blk in self.blocks: @@ -861,13 +861,13 @@ def forward( def rot_pos_emb(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) - return rotary_pos_emb, pos_ids + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids class Glm4vMoeTextRotaryEmbedding(nn.Module): @@ -1260,7 +1260,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1281,7 +1281,7 @@ def get_video_features( pixel_values_videos, grid_thw=flattened_video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, return_dict=True, **kwargs, ) @@ -1298,7 +1298,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1306,17 +1306,13 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from the processor). - image_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs (from the processor). """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1593,7 +1589,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1603,14 +1599,12 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). 
""" return self.model.get_video_features( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, **kwargs, ) @@ -1620,7 +1614,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1630,14 +1624,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index f3510aa3247e..940e4d5a5f94 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -41,7 +41,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm_image import GlmImageConfig, GlmImageTextConfig, GlmImageVisionConfig, GlmImageVQVAEConfig @@ -598,7 +598,7 @@ def forward( pixel_values: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -608,16 +608,16 @@ def forward( The temporal, height and width of feature shape of each image. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + position_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. """ hidden_states = self.patch_embed(pixel_values) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -627,8 +627,8 @@ def forward( hidden_states, seqlens, grid_thw, - rotary_pos_ids[:, 0].to(hidden_states.device), - rotary_pos_ids[:, 1].to(hidden_states.device), + position_ids[:, 0].to(hidden_states.device), + position_ids[:, 1].to(hidden_states.device), ) # Transformer blocks (no position_embeddings needed, already added above) @@ -642,11 +642,11 @@ def forward( def rot_pos_emb(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - return get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + return get_vision_position_ids(grid_thw, self.spatial_merge_size) @use_kernel_forward_from_hub("RMSNorm") @@ -1164,8 +1164,8 @@ def get_rope_index( if all_decode_position_ids: max_decode_len = max(x.shape[1] for x in all_decode_position_ids) padded_decode_pos_ids = [ - F.pad(pos_ids, (0, max_decode_len - pos_ids.shape[1]), mode="replicate") - for pos_ids in all_decode_position_ids + F.pad(position_ids, (0, max_decode_len - position_ids.shape[1]), mode="replicate") + for position_ids in all_decode_position_ids ] self._cached_decode_position_ids = torch.stack(padded_decode_pos_ids, dim=0) # [batch, 3, max_decode_len] else: @@ -1182,7 +1182,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1196,7 +1196,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index d5d7134f508b..ee533de4dbc0 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -35,7 +35,7 @@ from ...utils.generic import merge_with_config_defaults from ...utils.import_utils import requires from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..chameleon.modeling_chameleon import ChameleonVQVAE, ChameleonVQVAEModelOutput, ChameleonVQVAEVectorQuantizer from ..glm4v.configuration_glm4v import Glm4vTextConfig, Glm4vVisionConfig from ..glm4v.modeling_glm4v import ( @@ -438,7 +438,7 @@ def forward( pixel_values: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -448,16 +448,16 @@ def forward( The temporal, height and width of feature shape of each image. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + position_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. 
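The rotary module itself now reduces to a broadcast multiply over the precomputed ids, so the same frequencies can be reproduced outside the model; a small sketch (theta of 10000.0 matches the module default above, while the per-axis dimension of 40 and the three example tokens are illustrative):

import torch

dim, theta = 40, 10000.0                                     # per-axis rotary dim is illustrative
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

position_ids = torch.tensor([[0, 0], [0, 1], [1, 0]])        # (row, col) for three patch tokens
freqs = (position_ids.unsqueeze(-1) * inv_freq).flatten(1)   # same broadcast as VisionRotaryEmbedding.forward, shape (3, dim)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()                              # the (cos, sin) pair consumed by the attention blocks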
""" hidden_states = self.patch_embed(pixel_values) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -467,8 +467,8 @@ def forward( hidden_states, seqlens, grid_thw, - rotary_pos_ids[:, 0].to(hidden_states.device), - rotary_pos_ids[:, 1].to(hidden_states.device), + position_ids[:, 0].to(hidden_states.device), + position_ids[:, 1].to(hidden_states.device), ) # Transformer blocks (no position_embeddings needed, already added above) @@ -482,11 +482,11 @@ def forward( def rot_pos_emb(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - return get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + return get_vision_position_ids(grid_thw, self.spatial_merge_size) class GlmImageTextModel(Glm4vTextModel): @@ -667,8 +667,8 @@ def get_rope_index( if all_decode_position_ids: max_decode_len = max(x.shape[1] for x in all_decode_position_ids) padded_decode_pos_ids = [ - F.pad(pos_ids, (0, max_decode_len - pos_ids.shape[1]), mode="replicate") - for pos_ids in all_decode_position_ids + F.pad(position_ids, (0, max_decode_len - position_ids.shape[1]), mode="replicate") + for position_ids in all_decode_position_ids ] self._cached_decode_position_ids = torch.stack(padded_decode_pos_ids, dim=0) # [batch, 3, max_decode_len] else: @@ -719,7 +719,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -733,7 +733,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index e1702e136461..5925675fd7bc 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -43,7 +43,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm_ocr import GlmOcrConfig, GlmOcrTextConfig, GlmOcrVisionConfig @@ -313,8 +313,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * 
self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) @auto_docstring @@ -586,7 +586,7 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: r""" @@ -596,21 +596,21 @@ def forward( The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + position_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -637,13 +637,13 @@ def forward( def rot_pos_emb(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) - return rotary_pos_emb, pos_ids + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids class GlmOcrTextRotaryEmbedding(nn.Module): @@ -1008,7 +1008,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1029,7 +1029,7 @@ def get_video_features( pixel_values_videos, grid_thw=flattened_video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, return_dict=True, **kwargs, ) @@ -1046,7 +1046,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1054,17 +1054,13 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
- image_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from the processor). - image_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs (from the processor). """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1286,7 +1282,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1296,14 +1292,12 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, **kwargs, ) @@ -1313,7 +1307,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1323,14 +1317,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) diff --git a/src/transformers/models/glm_ocr/modular_glm_ocr.py b/src/transformers/models/glm_ocr/modular_glm_ocr.py index e37f37d05a60..63812ab0d255 100644 --- a/src/transformers/models/glm_ocr/modular_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modular_glm_ocr.py @@ -21,7 +21,7 @@ from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import auto_docstring -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..glm4v.configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig from ..glm4v.modeling_glm4v import ( Glm4vForConditionalGeneration, @@ -252,7 +252,7 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: r""" @@ -262,21 +262,21 @@ def forward( The temporal, height and width of feature shape of each image in LLM. 
cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + position_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 075cbe6d1a69..716fa78b15dd 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -46,7 +46,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check, torch_int from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_paddleocr_vl import PaddleOCRTextConfig, PaddleOCRVisionConfig, PaddleOCRVLConfig @@ -99,8 +99,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class PaddleOCRRotaryEmbedding(nn.Module): @@ -822,7 +822,7 @@ def forward( grid_thw: torch.LongTensor | None = None, cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -836,7 +836,7 @@ def forward( The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - rotary_pos_ids (`torch.Tensor` of shape `(sequence_length, 2)`, *optional*): + position_ids (`torch.Tensor` of shape `(sequence_length, 2)`, *optional*): Precomputed rotary position ids as `(row, column)` pairs. If not provided, will be computed based on `image_grid_thw`. """ @@ -847,12 +847,12 @@ def forward( attention_mask=attention_mask, ) - if rotary_pos_ids is None: + if position_ids is None: # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). 
- rotary_pos_ids = get_rotary_pos_ids(grid_thw, 1) + position_ids = get_vision_position_ids(grid_thw, 1) - rotary_embeddings = self.rotary_pos_emb(rotary_pos_ids) + rotary_embeddings = self.rotary_pos_emb(position_ids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) @@ -900,7 +900,7 @@ def forward( grid_thw: torch.LongTensor | None = None, cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ @@ -920,7 +920,7 @@ def forward( grid_thw=grid_thw, cu_seqlens=cu_seqlens, attention_mask=attention_mask, - rotary_pos_ids=rotary_pos_ids, + position_ids=position_ids, **kwargs, ) @@ -951,7 +951,7 @@ def forward( pixel_values: torch.FloatTensor, grid_thw: torch.LongTensor | None = None, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -967,7 +967,7 @@ def forward( pixel_values=pixel_values, grid_thw=grid_thw, cu_seqlens=cu_seqlens, - rotary_pos_ids=rotary_pos_ids, + position_ids=position_ids, **kwargs, ) @@ -1217,7 +1217,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1231,7 +1231,7 @@ def get_image_features( pixel_values=pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) @@ -1398,7 +1398,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1408,14 +1408,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
""" return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 268e34616983..76c22a6fe116 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -57,7 +57,7 @@ ) from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..ernie4_5.configuration_ernie4_5 import Ernie4_5Config from ..ernie4_5.modeling_ernie4_5 import ( Ernie4_5DecoderLayer, @@ -790,7 +790,7 @@ def forward( grid_thw: torch.LongTensor | None = None, cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -804,7 +804,7 @@ def forward( The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - rotary_pos_ids (`torch.Tensor` of shape `(sequence_length, 2)`, *optional*): + position_ids (`torch.Tensor` of shape `(sequence_length, 2)`, *optional*): Precomputed rotary position ids as `(row, column)` pairs. If not provided, will be computed based on `image_grid_thw`. """ @@ -815,12 +815,12 @@ def forward( attention_mask=attention_mask, ) - if rotary_pos_ids is None: + if position_ids is None: # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). 
- rotary_pos_ids = get_rotary_pos_ids(grid_thw, 1) + position_ids = get_vision_position_ids(grid_thw, 1) - rotary_embeddings = self.rotary_pos_emb(rotary_pos_ids) + rotary_embeddings = self.rotary_pos_emb(position_ids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) @@ -868,7 +868,7 @@ def forward( grid_thw: torch.LongTensor | None = None, cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ @@ -888,7 +888,7 @@ def forward( grid_thw=grid_thw, cu_seqlens=cu_seqlens, attention_mask=attention_mask, - rotary_pos_ids=rotary_pos_ids, + position_ids=position_ids, **kwargs, ) @@ -919,7 +919,7 @@ def forward( pixel_values: torch.FloatTensor, grid_thw: torch.LongTensor | None = None, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -935,7 +935,7 @@ def forward( pixel_values=pixel_values, grid_thw=grid_thw, cu_seqlens=cu_seqlens, - rotary_pos_ids=rotary_pos_ids, + position_ids=position_ids, **kwargs, ) @@ -976,7 +976,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -990,7 +990,7 @@ def get_image_features( pixel_values=pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 89f648c69265..39dff932bff2 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -55,7 +55,7 @@ from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.hub import cached_file from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids, get_vision_window_index from .configuration_qwen2_5_omni import ( Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniBigVGANConfig, @@ -1165,8 +1165,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen2_5_VisionPatchEmbed(nn.Module): @@ -1257,7 +1257,7 @@ def forward( cu_seqlens: torch.Tensor | None = None, window_index: torch.Tensor | None = None, cu_window_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], 
) -> tuple | BaseModelOutputWithPooling: """ @@ -1269,27 +1269,27 @@ def forward( cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). cu_window_seqlens (`torch.Tensor`, *optional*): - Precomputed window cumulative sequence lengths (from `get_window_index`). + Precomputed window cumulative sequence lengths (from `get_vision_window_index`). window_index (`torch.Tensor`, *optional*): - Precomputed window reordering index (from `get_window_index`). - rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + Precomputed window reordering index (from `get_vision_window_index`). + position_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) if window_index is None: - window_index, cu_window_seqlens = get_window_index( + window_index, cu_window_seqlens = get_vision_window_index( grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit ) @@ -1326,21 +1326,21 @@ def forward( def rot_pos_emb(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb - def get_window_index(self, grid_thw): + def get_vision_window_index(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_window_index` from `transformers.vision_utils` instead.", + f"`{self.__class__.__name__}.get_vision_window_index` is deprecated and will be removed in a future version. 
Use `get_vision_window_index` from `transformers.vision_utils` instead.", FutureWarning, stacklevel=2, ) - window_index, cu_window_seqlens = get_window_index( + window_index, cu_window_seqlens = get_vision_window_index( grid_thw, self.spatial_merge_size, self.window_size, @@ -1779,7 +1779,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1793,7 +1793,7 @@ def get_video_features( pixel_values_videos, grid_thw=video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, **kwargs, ) @@ -1804,7 +1804,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1818,7 +1818,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, **kwargs, ) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 90abc7a27cf7..9c54b35d6b2d 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -47,7 +47,7 @@ from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults from ...utils.hub import cached_file from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids, get_vision_window_index from ..llama.modeling_llama import LlamaRotaryEmbedding, rotate_half from ..qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig from ..qwen2_5_vl.modeling_qwen2_5_vl import ( @@ -1612,7 +1612,7 @@ def forward( cu_seqlens: torch.Tensor | None = None, window_index: torch.Tensor | None = None, cu_window_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -1624,27 +1624,27 @@ def forward( cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). cu_window_seqlens (`torch.Tensor`, *optional*): - Precomputed window cumulative sequence lengths (from `get_window_index`). + Precomputed window cumulative sequence lengths (from `get_vision_window_index`). window_index (`torch.Tensor`, *optional*): - Precomputed window reordering index (from `get_window_index`). - rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + Precomputed window reordering index (from `get_vision_window_index`). + position_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. 
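For the windowed Qwen-style encoders, the window reordering tensors can be precomputed the same way; a sketch using literal values in place of the module attributes (`spatial_merge_size`, `window_size`, `patch_size`, and `spatial_merge_unit` should be read from the loaded vision module or its config in practice; the numbers below are only the usual Qwen2.5-VL-style defaults):

import torch
from transformers.vision_utils import (
    get_vision_cu_seqlens,
    get_vision_position_ids,
    get_vision_window_index,
)

grid_thw = torch.tensor([[1, 32, 32]])                      # one image, 32 x 32 patches
spatial_merge_size, window_size, patch_size = 2, 112, 14    # illustrative; read from the vision module in practice
spatial_merge_unit = spatial_merge_size**2

cu_seqlens = get_vision_cu_seqlens(grid_thw)
position_ids = get_vision_position_ids(grid_thw, spatial_merge_size)
window_index, cu_window_seqlens = get_vision_window_index(
    grid_thw, spatial_merge_size, window_size, patch_size, spatial_merge_unit
)
# All four tensors are then passed to the encoder forward:
#   visual(pixel_values, grid_thw, cu_seqlens=cu_seqlens, position_ids=position_ids,
#          window_index=window_index, cu_window_seqlens=cu_window_seqlens)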
""" hidden_states = self.patch_embed(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) if window_index is None: - window_index, cu_window_seqlens = get_window_index( + window_index, cu_window_seqlens = get_vision_window_index( grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit ) @@ -1764,7 +1764,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1778,7 +1778,7 @@ def get_video_features( pixel_values_videos, grid_thw=video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, **kwargs, ) @@ -1789,7 +1789,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1803,7 +1803,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, **kwargs, ) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 9e54f36e06fe..f93e3ffa1061 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -47,7 +47,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids, get_vision_window_index from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig @@ -125,8 +125,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen2_5_VLPatchMerger(nn.Module): @@ -387,7 +387,7 @@ def forward( cu_seqlens: torch.Tensor | None = None, window_index: torch.Tensor | None = None, cu_window_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -399,27 +399,27 @@ def forward( 
cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). window_index (`torch.Tensor`, *optional*): - Precomputed window reordering index (from `get_window_index`). + Precomputed window reordering index (from `get_vision_window_index`). cu_window_seqlens (`torch.Tensor`, *optional*): - Precomputed window cumulative sequence lengths (from `get_window_index`). - rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + Precomputed window cumulative sequence lengths (from `get_vision_window_index`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) if window_index is None: - window_index, cu_window_seqlens = get_window_index( + window_index, cu_window_seqlens = get_vision_window_index( grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit ) @@ -458,21 +458,21 @@ def forward( def rot_pos_emb(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb - def get_window_index(self, grid_thw): + def get_vision_window_index(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_window_index` from `transformers.vision_utils` instead.", + f"`{self.__class__.__name__}.get_vision_window_index` is deprecated and will be removed in a future version. Use `get_vision_window_index` from `transformers.vision_utils` instead.", FutureWarning, stacklevel=2, ) - window_index, cu_window_seqlens = get_window_index( + window_index, cu_window_seqlens = get_vision_window_index( grid_thw, self.spatial_merge_size, self.window_size, @@ -1110,7 +1110,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1118,17 +1118,13 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
- video_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from the processor). - video_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs (from the processor). """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) vision_outputs = self.visual( pixel_values_videos, grid_thw=video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, **kwargs, ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1144,7 +1140,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1152,17 +1148,13 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from the processor). - image_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs (from the processor). """ pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1278,9 +1270,9 @@ def forward( mm_token_type_ids: torch.IntTensor | None = None, second_per_grid_ts: torch.Tensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2_5_VLModelOutputWithPast: r""" @@ -1302,7 +1294,7 @@ def forward( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( @@ -1315,7 +1307,7 @@ def forward( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( @@ -1407,7 +1399,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1417,14 +1409,12 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). 
- video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, **kwargs, ) @@ -1434,7 +1424,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1444,14 +1434,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index d3c3a3987d71..1b38a12d8e30 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -40,7 +40,7 @@ from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens, get_window_index +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids, get_vision_window_index from ..llama.modeling_llama import LlamaRMSNorm from ..qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig from ..qwen2_vl.modeling_qwen2_vl import ( @@ -228,7 +228,7 @@ def forward( cu_seqlens: torch.Tensor | None = None, window_index: torch.Tensor | None = None, cu_window_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -240,27 +240,27 @@ def forward( cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). window_index (`torch.Tensor`, *optional*): - Precomputed window reordering index (from `get_window_index`). + Precomputed window reordering index (from `get_vision_window_index`). cu_window_seqlens (`torch.Tensor`, *optional*): - Precomputed window cumulative sequence lengths (from `get_window_index`). - rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + Precomputed window cumulative sequence lengths (from `get_vision_window_index`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. 
""" hidden_states = self.patch_embed(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) if window_index is None: - window_index, cu_window_seqlens = get_window_index( + window_index, cu_window_seqlens = get_vision_window_index( grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit ) @@ -299,21 +299,21 @@ def forward( def rot_pos_emb(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb - def get_window_index(self, grid_thw): + def get_vision_window_index(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_window_index` from `transformers.vision_utils` instead.", + f"`{self.__class__.__name__}.get_vision_window_index` is deprecated and will be removed in a future version. 
Use `get_vision_window_index` from `transformers.vision_utils` instead.", FutureWarning, stacklevel=2, ) - window_index, cu_window_seqlens = get_window_index( + window_index, cu_window_seqlens = get_vision_window_index( grid_thw, self.spatial_merge_size, self.window_size, @@ -520,9 +520,9 @@ def forward( mm_token_type_ids: torch.IntTensor | None = None, second_per_grid_ts: torch.Tensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2_5_VLModelOutputWithPast: r""" @@ -544,7 +544,7 @@ def forward( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( @@ -557,7 +557,7 @@ def forward( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 32897705ad41..d476e452a8b7 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -53,7 +53,7 @@ merge_with_config_defaults, ) from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig @@ -279,8 +279,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class PatchEmbed(nn.Module): @@ -729,7 +729,7 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: r""" @@ -737,15 +737,15 @@ def forward( The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). 
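As a shape reference, a standalone re-statement of the broadcast-multiply rotary forward above; the dimension and inputs are illustrative:

import torch
from torch import nn

class TinyVisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        # (total_tokens, 2) -> (total_tokens, 2, dim // 2) -> (total_tokens, dim)
        return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1)

rotary = TinyVisionRotaryEmbedding(dim=16)
position_ids = torch.tensor([[0, 0], [0, 1], [1, 0]])  # (row, col) per token
print(rotary(position_ids).shape)  # torch.Size([3, 16])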
""" hidden_states = self.patch_embed(hidden_states) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -770,12 +770,12 @@ def forward( def rot_pos_emb(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb @@ -1084,7 +1084,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1092,17 +1092,13 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from the processor). - video_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs (from the processor). """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) vision_outputs = self.visual( pixel_values_videos, grid_thw=video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, **kwargs, ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1118,7 +1114,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1126,17 +1122,13 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from the processor). - image_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs (from the processor). 
""" pixel_values = pixel_values.type(self.visual.dtype) vision_outputs = self.visual( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, **kwargs, ) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1252,9 +1244,9 @@ def forward( rope_deltas: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2VLModelOutputWithPast: r""" @@ -1264,14 +1256,6 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. - image_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths for images (from the processor). - image_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs for images (from the processor). - video_cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from the processor). - video_rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed rotary position IDs for videos (from the processor). """ if inputs_embeds is None: @@ -1282,7 +1266,7 @@ def forward( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( @@ -1295,7 +1279,7 @@ def forward( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( @@ -1354,7 +1338,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1364,14 +1348,12 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). 
""" return self.model.get_video_features( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, **kwargs, ) @@ -1381,7 +1363,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1391,14 +1373,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) @@ -1420,9 +1400,9 @@ def forward( rope_deltas: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2VLCausalLMOutputWithPast: @@ -1439,12 +1419,8 @@ def forward( The rope index difference between sequence length and multimodal rope. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). 
Example: @@ -1491,9 +1467,9 @@ def forward( video_grid_thw=video_grid_thw, mm_token_type_ids=mm_token_type_ids, image_cu_seqlens=image_cu_seqlens, - image_rotary_pos_ids=image_rotary_pos_ids, + image_position_ids=image_position_ids, video_cu_seqlens=video_cu_seqlens, - video_rotary_pos_ids=video_rotary_pos_ids, + video_position_ids=video_position_ids, position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 51922261511a..d506a024edca 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -49,7 +49,7 @@ from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig @@ -78,8 +78,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3_5TextRotaryEmbedding(nn.Module): @@ -1034,8 +1034,8 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, - embed_indices: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + bilinear_indices: torch.Tensor | None = None, bilinear_weights: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: @@ -1047,29 +1047,29 @@ def forward( The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). - embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). + bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). + Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. 
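For illustration, a minimal sketch of precomputing the bilinear interpolation tensors outside the encoder and applying them to a position-embedding table, mirroring the weighted sum described above; the table size, grid, and config scalars are illustrative assumptions, and the import path follows this patch:

import torch
from torch import nn
from transformers.vision_utils import get_vision_bilinear_indices_and_weights

grid_thw = torch.tensor([[1, 8, 8]])  # one 8x8-patch image, illustrative
num_grid_per_side = 48                # illustrative stand-ins for the vision config values
spatial_merge_size = 2
hidden_size = 1152

bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights(
    grid_thw, num_grid_per_side, spatial_merge_size
)
# Both returned tensors are (4, total_thw): one row per bilinear corner.

pos_embed = nn.Embedding(num_grid_per_side**2, hidden_size)  # assumes a square table of side num_grid_per_side
pos_embeds = (pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0)
# pos_embeds: (total_thw, hidden_size), added to the patch embeddings before the encoder blocks.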
""" hidden_states = self.patch_embed(hidden_states) - if embed_indices is None or bilinear_weights is None: - embed_indices, bilinear_weights = get_pos_embed_indices( + if bilinear_indices is None or bilinear_weights is None: + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1097,24 +1097,24 @@ def forward( def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", FutureWarning, stacklevel=2, ) - embed_indices, bilinear_weights = get_pos_embed_indices( + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) @dataclass @@ -1420,7 +1420,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1434,7 +1434,7 @@ def get_video_features( pixel_values_videos, video_grid_thw, image_cu_seqlens=video_cu_seqlens, - image_rotary_pos_ids=video_rotary_pos_ids, + image_position_ids=video_position_ids, **kwargs, ) @@ -1445,7 +1445,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1459,7 +1459,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) @@ -1775,7 +1775,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1785,14 +1785,12 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, **kwargs, ) @@ -1802,7 +1800,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1812,14 +1810,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
""" return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index aa8aef5dc0b7..d55d1abc1e2c 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -30,7 +30,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from ..qwen3.modeling_qwen3 import Qwen3ForCausalLM from ..qwen3_next.configuration_qwen3_next import Qwen3NextConfig from ..qwen3_next.modeling_qwen3_next import ( @@ -426,8 +426,8 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, - embed_indices: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + bilinear_indices: torch.Tensor | None = None, bilinear_weights: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: @@ -439,29 +439,29 @@ def forward( The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). - embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). + bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). + Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. 
""" hidden_states = self.patch_embed(hidden_states) - if embed_indices is None or bilinear_weights is None: - embed_indices, bilinear_weights = get_pos_embed_indices( + if bilinear_indices is None or bilinear_weights is None: + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -580,7 +580,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: pixel_values = pixel_values.type(self.visual.dtype) @@ -588,7 +588,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index dd35f8887fb1..e8d78e7cb0a0 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -50,7 +50,7 @@ from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import OutputRecorder, capture_outputs -from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig, Qwen3_5MoeVisionConfig @@ -79,8 +79,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3_5MoeTextRotaryEmbedding(nn.Module): @@ -1127,8 +1127,8 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, - embed_indices: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + bilinear_indices: torch.Tensor | None = None, bilinear_weights: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: @@ -1140,29 +1140,29 @@ def forward( The temporal, height and width of feature shape of each image in LLM. 
cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). - embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). + bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). + Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - if embed_indices is None or bilinear_weights is None: - embed_indices, bilinear_weights = get_pos_embed_indices( + if bilinear_indices is None or bilinear_weights is None: + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1190,24 +1190,24 @@ def forward( def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", FutureWarning, stacklevel=2, ) - embed_indices, bilinear_weights = get_pos_embed_indices( + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) @dataclass @@ -1545,7 +1545,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1559,7 +1559,7 @@ def get_video_features( pixel_values_videos, video_grid_thw, image_cu_seqlens=video_cu_seqlens, - image_rotary_pos_ids=video_rotary_pos_ids, + image_position_ids=video_position_ids, **kwargs, ) @@ -1570,7 +1570,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1584,7 +1584,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) @@ -1974,7 +1974,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1984,14 +1984,12 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, **kwargs, ) @@ -2001,7 +1999,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -2011,14 +2009,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). 
""" return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index c787f7d76bc5..eb1f261b5964 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -62,7 +62,7 @@ merge_with_config_defaults, ) from ...utils.output_capturing import OutputRecorder, capture_outputs -from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_omni_moe import ( Qwen3OmniMoeAudioEncoderConfig, Qwen3OmniMoeCode2WavConfig, @@ -1030,8 +1030,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3OmniMoeTextTopKRouter(nn.Module): @@ -1174,8 +1174,8 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, - embed_indices: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + bilinear_indices: torch.Tensor | None = None, bilinear_weights: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: @@ -1187,29 +1187,29 @@ def forward( The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). - embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). + bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). + Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. 
""" hidden_states = self.patch_embed(hidden_states) - if embed_indices is None or bilinear_weights is None: - embed_indices, bilinear_weights = get_pos_embed_indices( + if bilinear_indices is None or bilinear_weights is None: + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1244,24 +1244,24 @@ def forward( def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", FutureWarning, stacklevel=2, ) - embed_indices, bilinear_weights = get_pos_embed_indices( + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) @property def deepstack_merger_list(self): @@ -1941,7 +1941,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1955,7 +1955,7 @@ def get_video_features( pixel_values_videos, grid_thw=video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, **kwargs, ) @@ -1966,7 +1966,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1980,7 +1980,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, **kwargs, ) diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 6b4658e73682..2bec2d4388c4 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1209,7 +1209,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1223,7 +1223,7 @@ def get_video_features( pixel_values_videos, grid_thw=video_grid_thw, cu_seqlens=video_cu_seqlens, - rotary_pos_ids=video_rotary_pos_ids, + position_ids=video_position_ids, **kwargs, ) @@ -1234,7 +1234,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1248,7 +1248,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, **kwargs, ) diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 953897234a25..b6beaf0cc6a7 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -42,7 +42,7 @@ from ...utils import TransformersKwargs, auto_docstring, 
can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig @@ -100,8 +100,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3VLVisionPatchMerger(nn.Module): @@ -665,8 +665,8 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, - embed_indices: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + bilinear_indices: torch.Tensor | None = None, bilinear_weights: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: @@ -678,29 +678,29 @@ def forward( The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). - embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). + bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). + Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. 
""" hidden_states = self.patch_embed(hidden_states) - if embed_indices is None or bilinear_weights is None: - embed_indices, bilinear_weights = get_pos_embed_indices( + if bilinear_indices is None or bilinear_weights is None: + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -735,24 +735,24 @@ def forward( def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", FutureWarning, stacklevel=2, ) - embed_indices, bilinear_weights = get_pos_embed_indices( + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) @auto_docstring( @@ -1062,7 +1062,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1076,7 +1076,7 @@ def get_video_features( pixel_values_videos, video_grid_thw, image_cu_seqlens=video_cu_seqlens, - image_rotary_pos_ids=video_rotary_pos_ids, + image_position_ids=video_position_ids, **kwargs, ) @@ -1087,7 +1087,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1101,7 +1101,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) @@ -1217,9 +1217,9 @@ def forward( video_grid_thw: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen3VLModelOutputWithPast: r""" @@ -1242,7 +1242,7 @@ def forward( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, return_dict=True, ) image_embeds = image_outputs.pooler_output @@ -1258,7 +1258,7 @@ def forward( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, return_dict=True, ) video_embeds = video_outputs.pooler_output @@ -1375,7 +1375,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1385,14 +1385,12 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). 
""" return self.model.get_video_features( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, **kwargs, ) @@ -1402,7 +1400,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1412,14 +1410,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index ab885673a9eb..10f6c2778bb9 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -40,7 +40,7 @@ from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput -from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from ..llama.modeling_llama import LlamaRotaryEmbedding from ..qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLCausalLMOutputWithPast, @@ -453,8 +453,8 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, - embed_indices: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + bilinear_indices: torch.Tensor | None = None, bilinear_weights: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: @@ -466,29 +466,29 @@ def forward( The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). - embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). + bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). + Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. 
""" hidden_states = self.patch_embed(hidden_states) - if embed_indices is None or bilinear_weights is None: - embed_indices, bilinear_weights = get_pos_embed_indices( + if bilinear_indices is None or bilinear_weights is None: + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -523,24 +523,24 @@ def forward( def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", FutureWarning, stacklevel=2, ) - embed_indices, bilinear_weights = get_pos_embed_indices( + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) @auto_docstring( @@ -712,7 +712,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -726,7 +726,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) @@ -744,7 +744,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -758,7 +758,7 @@ def get_video_features( pixel_values_videos, video_grid_thw, image_cu_seqlens=video_cu_seqlens, - image_rotary_pos_ids=video_rotary_pos_ids, + image_position_ids=video_position_ids, **kwargs, ) @@ -777,9 +777,9 @@ def forward( video_grid_thw: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen3VLModelOutputWithPast: r""" @@ -802,7 +802,7 @@ def forward( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, return_dict=True, ) image_embeds = image_outputs.pooler_output @@ -818,7 +818,7 @@ def forward( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, return_dict=True, ) video_embeds = video_outputs.pooler_output diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index a1f5125dec97..834e3cec18b7 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -48,7 +48,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs -from ...vision_utils import get_pos_embed_indices, get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig @@ -401,8 +401,8 @@ def __init__(self, dim: 
int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) def apply_rotary_pos_emb_vision( @@ -650,8 +650,8 @@ def forward( hidden_states: torch.Tensor, grid_thw: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, - embed_indices: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + bilinear_indices: torch.Tensor | None = None, bilinear_weights: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: @@ -663,29 +663,29 @@ def forward( The temporal, height and width of feature shape of each image in LLM. cu_seqlens (`torch.Tensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). - embed_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices into the position embedding table (from `get_pos_embed_indices`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). + bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): + Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Interpolation weights for the four bilinear corners (from `get_pos_embed_indices`). + Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. """ hidden_states = self.patch_embed(hidden_states) - if embed_indices is None or bilinear_weights is None: - embed_indices, bilinear_weights = get_pos_embed_indices( + if bilinear_indices is None or bilinear_weights is None: + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - pos_embeds = (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -720,24 +720,24 @@ def forward( def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_rotary_pos_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", FutureWarning, stacklevel=2, ) - pos_ids = get_rotary_pos_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(pos_ids) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_pos_embed_indices` from `transformers.vision_utils` and apply `self.pos_embed`.", + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", FutureWarning, stacklevel=2, ) - embed_indices, bilinear_weights = get_pos_embed_indices( + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) - return (self.pos_embed(embed_indices) * bilinear_weights[:, :, None]).sum(0) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) class Qwen3VLMoeTextRotaryEmbedding(nn.Module): @@ -1192,7 +1192,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1206,7 +1206,7 @@ def get_video_features( pixel_values_videos, video_grid_thw, image_cu_seqlens=video_cu_seqlens, - image_rotary_pos_ids=video_rotary_pos_ids, + image_position_ids=video_position_ids, **kwargs, ) @@ -1217,7 +1217,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1231,7 +1231,7 @@ def get_image_features( pixel_values, grid_thw=image_grid_thw, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) @@ -1347,9 +1347,9 @@ def forward( video_grid_thw: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen3VLMoeModelOutputWithPast: r""" @@ -1372,7 +1372,7 @@ def forward( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, return_dict=True, ) image_embeds = image_outputs.pooler_output @@ -1388,7 +1388,7 @@ def forward( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, return_dict=True, ) video_embeds = video_outputs.pooler_output @@ -1558,7 +1558,7 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = 
None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1568,14 +1568,12 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. video_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). - video_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for video rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_video_features( pixel_values_videos, video_grid_thw, video_cu_seqlens, - video_rotary_pos_ids, + video_position_ids, **kwargs, ) @@ -1585,7 +1583,7 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1595,14 +1593,12 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. image_cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - image_rotary_pos_ids (`torch.LongTensor`, *optional*): - Precomputed (row, col) position IDs for image rotary embeddings (from `get_rotary_pos_ids`). """ return self.model.get_image_features( pixel_values, image_grid_thw, image_cu_seqlens, - image_rotary_pos_ids, + image_position_ids, **kwargs, ) diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 8028fd2e9f73..fdad9f96a9bc 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -36,7 +36,7 @@ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ...vision_utils import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..auto.modeling_auto import AutoModel from .configuration_video_llama_3 import VideoLlama3Config, VideoLlama3VisionConfig @@ -51,8 +51,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, pos_ids: torch.Tensor) -> torch.Tensor: - return (pos_ids.unsqueeze(-1) * self.inv_freq).flatten(1) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class VideoLlama3VisionEmbeddings(nn.Module): @@ -415,7 +415,7 @@ def forward( grid_thw: torch.Tensor, merge_sizes: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutput: r""" @@ -425,16 +425,16 @@ def forward( The spatial downsampling ratio of each image or video feature. 
cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + position_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). """ hidden_states = self.embeddings(pixel_values.type(self.dtype)) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, merge_sizes) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, merge_sizes) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -530,7 +530,7 @@ def get_video_features( video_grid_thw: torch.LongTensor, video_merge_sizes: torch.LongTensor, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -546,7 +546,7 @@ def get_video_features( image_grid_thw=video_grid_thw, image_merge_sizes=video_merge_sizes, image_cu_seqlens=video_cu_seqlens, - image_rotary_pos_ids=video_rotary_pos_ids, + image_position_ids=video_position_ids, **kwargs, ) @@ -558,7 +558,7 @@ def get_image_features( image_grid_thw: torch.LongTensor, image_merge_sizes: torch.LongTensor, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -574,7 +574,7 @@ def get_image_features( grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) @@ -766,7 +766,7 @@ def get_video_features( video_grid_thw: torch.LongTensor | None = None, video_merge_sizes: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -782,7 +782,7 @@ def get_video_features( video_grid_thw, video_merge_sizes, video_cu_seqlens=video_cu_seqlens, - video_rotary_pos_ids=video_rotary_pos_ids, + video_position_ids=video_position_ids, **kwargs, ) @@ -794,7 +794,7 @@ def get_image_features( image_grid_thw: torch.LongTensor | None = None, image_merge_sizes: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -810,7 +810,7 @@ def get_image_features( image_grid_thw, image_merge_sizes, image_cu_seqlens=image_cu_seqlens, - image_rotary_pos_ids=image_rotary_pos_ids, + image_position_ids=image_position_ids, **kwargs, ) diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index ee3ec500cbdb..32a6ef84e474 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -52,7 +52,7 @@ group_videos_by_shape, reorder_videos, ) -from ...vision_utils 
import get_rotary_pos_ids, get_vision_cu_seqlens +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..auto import CONFIG_MAPPING, AutoConfig from ..auto.modeling_auto import AutoModel from ..qwen2_vl.image_processing_pil_qwen2_vl import Qwen2VLImageProcessorPil @@ -377,7 +377,7 @@ def forward( grid_thw: torch.Tensor, merge_sizes: torch.Tensor, cu_seqlens: torch.Tensor | None = None, - rotary_pos_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutput: r""" @@ -387,16 +387,16 @@ def forward( The spatial downsampling ratio of each image or video feature. cu_seqlens (`torch.IntTensor`, *optional*): Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - rotary_pos_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_rotary_pos_ids`). + position_ids (`torch.Tensor`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). """ hidden_states = self.embeddings(pixel_values.type(self.dtype)) - if rotary_pos_ids is None: - rotary_pos_ids = get_rotary_pos_ids(grid_thw, merge_sizes) + if position_ids is None: + position_ids = get_vision_position_ids(grid_thw, merge_sizes) - rotary_pos_emb = self.rotary_pos_emb(rotary_pos_ids) + rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) @@ -491,7 +491,7 @@ def get_video_features( video_grid_thw: torch.LongTensor, video_merge_sizes: torch.LongTensor, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -507,7 +507,7 @@ def get_video_features( image_grid_thw=video_grid_thw, image_merge_sizes=video_merge_sizes, image_cu_seqlens=video_cu_seqlens, - image_rotary_pos_ids=video_rotary_pos_ids, + image_position_ids=video_position_ids, **kwargs, ) @@ -519,7 +519,7 @@ def get_image_features( image_grid_thw: torch.LongTensor, image_merge_sizes: torch.LongTensor, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -535,7 +535,7 @@ def get_image_features( grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, cu_seqlens=image_cu_seqlens, - rotary_pos_ids=image_rotary_pos_ids, + position_ids=image_position_ids, return_dict=True, **kwargs, ) @@ -675,7 +675,7 @@ def get_image_features( image_grid_thw: torch.LongTensor | None = None, image_merge_sizes: torch.LongTensor | None = None, image_cu_seqlens: torch.Tensor | None = None, - image_rotary_pos_ids: torch.Tensor | None = None, + image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -691,7 +691,7 @@ def get_image_features( image_grid_thw, image_merge_sizes, image_cu_seqlens=image_cu_seqlens, - image_rotary_pos_ids=image_rotary_pos_ids, + image_position_ids=image_position_ids, **kwargs, ) @@ -703,7 +703,7 @@ def get_video_features( video_grid_thw: torch.LongTensor | None = None, video_merge_sizes: torch.LongTensor | None = None, video_cu_seqlens: torch.Tensor | None = None, - video_rotary_pos_ids: torch.Tensor | None = None, + video_position_ids: torch.Tensor | None = None, **kwargs: 
Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -719,7 +719,7 @@ def get_video_features( video_grid_thw, video_merge_sizes, video_cu_seqlens=video_cu_seqlens, - video_rotary_pos_ids=video_rotary_pos_ids, + video_position_ids=video_position_ids, **kwargs, ) diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index fec3ea6728fc..619de6f0122f 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -2135,14 +2135,14 @@ class ModelArgs: "shape": "of shape `(num_video_patches + 1,)`", } - image_rotary_pos_ids = { + image_position_ids = { "description": """ Precomputed (row, col) position IDs for image rotary embeddings. """, "shape": "of shape `(num_image_tokens, 2)`", } - video_rotary_pos_ids = { + video_position_ids = { "description": """ Precomputed (row, col) position IDs for video rotary embeddings. """, diff --git a/src/transformers/vision_utils.py b/src/transformers/vision_utils.py index 75c154594b29..9a065a4025e4 100644 --- a/src/transformers/vision_utils.py +++ b/src/transformers/vision_utils.py @@ -45,7 +45,7 @@ def get_vision_cu_seqlens(grid_thw: torch.Tensor) -> torch.Tensor: return F.pad(cu_seqlens, (1, 0), value=0) -def get_rotary_pos_ids(grid_thw: torch.Tensor, spatial_merge_size: int | torch.Tensor) -> torch.Tensor: +def get_vision_position_ids(grid_thw: torch.Tensor, spatial_merge_size: int | torch.Tensor) -> torch.Tensor: """Compute (row, col) position IDs for vision rotary embeddings. Args: @@ -54,13 +54,13 @@ def get_rotary_pos_ids(grid_thw: torch.Tensor, spatial_merge_size: int | torch.T or a ``(num_images_or_videos,)`` tensor (per-image). Returns: - ``pos_ids``: ``(total_tokens, 2)`` long — (row, col) position per token. + ``position_ids``: ``(total_tokens, 2)`` long — (row, col) position per token. """ device = grid_thw.device if isinstance(spatial_merge_size, int): spatial_merge_size = torch.tensor([spatial_merge_size], device=device).expand(len(grid_thw)) - pos_ids = [] + position_ids = [] for (t, h, w), m in zip(grid_thw.tolist(), spatial_merge_size.tolist()): t, h, w, m = int(t), int(h), int(w), int(m) hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w) @@ -68,12 +68,12 @@ def get_rotary_pos_ids(grid_thw: torch.Tensor, spatial_merge_size: int | torch.T wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1) wpos_ids = wpos_ids.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + position_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - return torch.cat(pos_ids, dim=0) + return torch.cat(position_ids, dim=0) -def get_window_index( +def get_vision_window_index( grid_thw: torch.Tensor, spatial_merge_size: int, window_size: int, @@ -128,7 +128,7 @@ def get_window_index( return window_index, cu_window_seqlens -def get_pos_embed_indices( +def get_vision_bilinear_indices_and_weights( grid_thw: torch.Tensor, num_grid_per_side: int, spatial_merge_size: int ) -> tuple[torch.Tensor, torch.Tensor]: """Compute bilinear interpolation indices and weights for position embeddings. @@ -139,7 +139,7 @@ def get_pos_embed_indices( spatial_merge_size: merge block size from vision config. Returns: - ``embed_indices``: ``(4, total_thw)`` long — bilinear corner indices into pos_embed table. + ``bilinear_indices``: ``(4, total_thw)`` long — bilinear corner indices into pos_embed table. 
``bilinear_weights``: ``(4, total_thw)`` float — interpolation weights. """ N = num_grid_per_side @@ -187,6 +187,6 @@ def get_pos_embed_indices( idx_parts[i].append(raw_idx[i][reorder]) weight_parts[i].append(raw_w[i][reorder]) - embed_indices = torch.stack([torch.cat(p) for p in idx_parts]) + bilinear_indices = torch.stack([torch.cat(p) for p in idx_parts]) bilinear_weights = torch.stack([torch.cat(p) for p in weight_parts]) - return embed_indices, bilinear_weights + return bilinear_indices, bilinear_weights From 67b090673b483597e687232907611ef341f0384f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 21 Apr 2026 12:49:03 +0200 Subject: [PATCH 43/56] move some stuff to their original place --- .../modeling_ernie4_5_vl_moe.py | 20 ++++---- .../models/glm4v/modeling_glm4v.py | 20 ++++---- .../models/glm4v/modular_glm4v.py | 20 ++++---- .../models/glm4v_moe/modeling_glm4v_moe.py | 20 ++++---- .../models/glm_image/modeling_glm_image.py | 16 +++--- .../models/glm_image/modular_glm_image.py | 16 +++--- .../models/glm_ocr/modeling_glm_ocr.py | 20 ++++---- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 50 +++++++++---------- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 50 +++++++++---------- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 50 +++++++++---------- .../models/qwen2_vl/modeling_qwen2_vl.py | 20 ++++---- .../models/qwen3_5/modeling_qwen3_5.py | 42 ++++++++-------- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 42 ++++++++-------- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 42 ++++++++-------- .../models/qwen3_vl/modeling_qwen3_vl.py | 42 ++++++++-------- .../models/qwen3_vl/modular_qwen3_vl.py | 42 ++++++++-------- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 42 ++++++++-------- 17 files changed, 277 insertions(+), 277 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index a2d3e4350bda..e5a981b474cf 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -893,6 +893,16 @@ def __init__(self, config) -> None: self.post_init() + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb + @merge_with_config_defaults @capture_outputs def forward( @@ -933,16 +943,6 @@ def forward( hidden_states = self.ln(hidden_states) return BaseModelOutputWithPooling(last_hidden_state=hidden_states) - def rot_pos_emb(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb - class Ernie4_5_VLMoeVisionMLP(nn.Module): def __init__(self, config, in_dim, out_dim): diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 77a181c4e54d..5592b22c8790 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -726,6 +726,16 @@ def __init__(self, config) -> None: self.gradient_checkpointing = False self.post_init() + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids + @merge_with_config_defaults @capture_outputs @auto_docstring @@ -795,16 +805,6 @@ def forward( pooler_output=merged_hidden_states, ) - def rot_pos_emb(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb, position_ids - @auto_docstring class Glm4vTextModel(Glm4vPreTrainedModel): diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 121f02e6888c..a385c2784fb5 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -610,6 +610,16 @@ def __init__(self, config) -> None: self.gradient_checkpointing = False self.post_init() + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids + @merge_with_config_defaults @capture_outputs @auto_docstring @@ -679,16 +689,6 @@ def forward( pooler_output=merged_hidden_states, ) - def rot_pos_emb(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb, position_ids - class Glm4vTextModel(Qwen2_5_VLTextModel): _can_record_outputs = { diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index f7ebbbc0a562..a09f038998c5 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -790,6 +790,16 @@ def __init__(self, config) -> None: self.gradient_checkpointing = False self.post_init() + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids + @merge_with_config_defaults @capture_outputs @auto_docstring @@ -859,16 +869,6 @@ def forward( pooler_output=merged_hidden_states, ) - def rot_pos_emb(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb, position_ids - class Glm4vMoeTextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 940e4d5a5f94..5beed94615bb 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -590,6 +590,14 @@ def __init__(self, config: GlmImageVisionConfig) -> None: self.head_dim = head_dim self.post_init() + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + return get_vision_position_ids(grid_thw, self.spatial_merge_size) + @merge_with_config_defaults @capture_outputs @auto_docstring @@ -640,14 +648,6 @@ def forward( return BaseModelOutputWithPooling(last_hidden_state=hidden_states) - def rot_pos_emb(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - return get_vision_position_ids(grid_thw, self.spatial_merge_size) - @use_kernel_forward_from_hub("RMSNorm") class GlmImageRMSNorm(nn.Module): diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index ee533de4dbc0..b01242f444d3 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -430,6 +430,14 @@ def __init__(self, config: GlmImageVisionConfig): del self.downsample del self.post_layernorm + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + return get_vision_position_ids(grid_thw, self.spatial_merge_size) + @merge_with_config_defaults @capture_outputs @auto_docstring @@ -480,14 +488,6 @@ def forward( return BaseModelOutputWithPooling(last_hidden_state=hidden_states) - def rot_pos_emb(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - return get_vision_position_ids(grid_thw, self.spatial_merge_size) - class GlmImageTextModel(Glm4vTextModel): pass diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index 5925675fd7bc..d7429a6c32eb 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -578,6 +578,16 @@ def __init__(self, config) -> None: self.gradient_checkpointing = False self.post_init() + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids + @merge_with_config_defaults @capture_outputs @auto_docstring @@ -635,16 +645,6 @@ def forward( pooler_output=merged_hidden_states, ) - def rot_pos_emb(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb, position_ids - class GlmOcrTextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 39dff932bff2..b3989ed62c91 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1248,6 +1248,31 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> self.post_init() + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_vision_window_index` from `transformers.vision_utils` instead.", + FutureWarning, + stacklevel=2, + ) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + ) + return window_index, cu_window_seqlens.tolist() + @merge_with_config_defaults @capture_outputs def forward( @@ -1324,31 +1349,6 @@ def forward( pooler_output=merged_hidden_states, ) - def rot_pos_emb(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb - - def get_vision_window_index(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.get_vision_window_index` is deprecated and will be removed in a future version. Use `get_vision_window_index` from `transformers.vision_utils` instead.", - FutureWarning, - stacklevel=2, - ) - window_index, cu_window_seqlens = get_vision_window_index( - grid_thw, - self.spatial_merge_size, - self.window_size, - self.patch_size, - self.spatial_merge_unit, - ) - return window_index, cu_window_seqlens.tolist() - class Qwen2_5OmniRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index f93e3ffa1061..0525b599dd26 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -378,6 +378,31 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_vision_window_index` from `transformers.vision_utils` instead.", + FutureWarning, + stacklevel=2, + ) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + ) + return window_index, cu_window_seqlens.tolist() + @merge_with_config_defaults @capture_outputs def forward( @@ -456,31 +481,6 @@ def forward( pooler_output=merged_hidden_states, ) - def rot_pos_emb(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb - - def get_vision_window_index(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.get_vision_window_index` is deprecated and will be removed in a future version. Use `get_vision_window_index` from `transformers.vision_utils` instead.", - FutureWarning, - stacklevel=2, - ) - window_index, cu_window_seqlens = get_vision_window_index( - grid_thw, - self.spatial_merge_size, - self.window_size, - self.patch_size, - self.spatial_merge_unit, - ) - return window_index, cu_window_seqlens.tolist() - @dataclass @auto_docstring( diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 1b38a12d8e30..38a9e45bfdae 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -219,6 +219,31 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_vision_window_index` from `transformers.vision_utils` instead.", + FutureWarning, + stacklevel=2, + ) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + ) + return window_index, cu_window_seqlens.tolist() + @merge_with_config_defaults @capture_outputs def forward( @@ -297,31 +322,6 @@ def forward( pooler_output=merged_hidden_states, ) - def rot_pos_emb(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb - - def get_vision_window_index(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.get_vision_window_index` is deprecated and will be removed in a future version. Use `get_vision_window_index` from `transformers.vision_utils` instead.", - FutureWarning, - stacklevel=2, - ) - window_index, cu_window_seqlens = get_vision_window_index( - grid_thw, - self.spatial_merge_size, - self.window_size, - self.patch_size, - self.spatial_merge_unit, - ) - return window_index, cu_window_seqlens.tolist() - class Qwen2_5_VLModelOutputWithPast(Qwen2VLModelOutputWithPast): pass diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index d476e452a8b7..a2106844e76e 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -721,6 +721,16 @@ def get_dtype(self) -> torch.dtype: def get_device(self) -> torch.device: return self.blocks[0].mlp.fc2.weight.device + def rot_pos_emb(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb + @merge_with_config_defaults @capture_outputs @auto_docstring @@ -768,16 +778,6 @@ def forward( pooler_output=merged_hidden_states, ) - def rot_pos_emb(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb - @auto_docstring class Qwen2VLTextModel(Qwen2VLPreTrainedModel): diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index d506a024edca..fce209f5edd7 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1027,6 +1027,27 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + @merge_with_config_defaults @capture_outputs def forward( @@ -1095,27 +1116,6 @@ def forward( pooler_output=merged_hidden_states, ) - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb - - def fast_pos_embed_interpolate(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", - FutureWarning, - stacklevel=2, - ) - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) - @dataclass @auto_docstring( diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index e8d78e7cb0a0..03d2b4db8ce0 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1120,6 +1120,27 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + @merge_with_config_defaults @capture_outputs def forward( @@ -1188,27 +1209,6 @@ def forward( pooler_output=merged_hidden_states, ) - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb - - def fast_pos_embed_interpolate(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", - FutureWarning, - stacklevel=2, - ) - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) - @dataclass @auto_docstring( diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index eb1f261b5964..4c5b9e44e8ee 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1167,6 +1167,27 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + @merge_with_config_defaults @capture_outputs def forward( @@ -1242,27 +1263,6 @@ def forward( deepstack_features=deepstack_feature_lists, ) - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb - - def fast_pos_embed_interpolate(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", - FutureWarning, - stacklevel=2, - ) - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) - @property def deepstack_merger_list(self): return self.merger_list diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index b6beaf0cc6a7..f2a4182f6a8b 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -658,6 +658,27 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + @merge_with_config_defaults @capture_outputs def forward( @@ -733,27 +754,6 @@ def forward( deepstack_features=deepstack_feature_lists, ) - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb - - def fast_pos_embed_interpolate(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", - FutureWarning, - stacklevel=2, - ) - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) - @auto_docstring( custom_intro=( diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 10f6c2778bb9..ac2fc568a81a 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -446,6 +446,27 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + @merge_with_config_defaults @capture_outputs def forward( @@ -521,27 +542,6 @@ def forward( deepstack_features=deepstack_feature_lists, ) - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb - - def fast_pos_embed_interpolate(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", - FutureWarning, - stacklevel=2, - ) - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) - @auto_docstring( custom_intro=( diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 834e3cec18b7..43df218ea2bc 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -643,6 +643,27 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + @merge_with_config_defaults @capture_outputs def forward( @@ -718,27 +739,6 @@ def forward( deepstack_features=deepstack_feature_lists, ) - def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - warnings.warn( - f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", - FutureWarning, - stacklevel=2, - ) - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - return rotary_pos_emb - - def fast_pos_embed_interpolate(self, grid_thw): - warnings.warn( - f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", - FutureWarning, - stacklevel=2, - ) - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) - class Qwen3VLMoeTextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` From b8323fb8846fe6be235db90372f5f3163565ed3f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 21 Apr 2026 13:22:29 +0200 Subject: [PATCH 44/56] style --- .../models/glm_image/modeling_glm_image.py | 4 ++-- .../models/glm_image/modular_glm_image.py | 4 ++-- .../models/paddleocr_vl/modeling_paddleocr_vl.py | 11 +++++------ .../models/paddleocr_vl/modular_paddleocr_vl.py | 11 +++++------ 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 5beed94615bb..4c30e7a138f3 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -1164,8 +1164,8 @@ def get_rope_index( if all_decode_position_ids: max_decode_len = max(x.shape[1] for x in all_decode_position_ids) padded_decode_pos_ids = [ - F.pad(position_ids, (0, max_decode_len - position_ids.shape[1]), mode="replicate") - for position_ids in all_decode_position_ids + F.pad(pos_ids, (0, max_decode_len - pos_ids.shape[1]), mode="replicate") + for pos_ids in all_decode_position_ids ] self._cached_decode_position_ids = torch.stack(padded_decode_pos_ids, dim=0) # [batch, 3, max_decode_len] else: diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index b01242f444d3..253c2f04821d 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -667,8 +667,8 @@ def get_rope_index( if all_decode_position_ids: max_decode_len = max(x.shape[1] for x in all_decode_position_ids) padded_decode_pos_ids = [ - F.pad(position_ids, (0, max_decode_len - position_ids.shape[1]), mode="replicate") - for position_ids in all_decode_position_ids + F.pad(pos_ids, (0, max_decode_len - pos_ids.shape[1]), mode="replicate") + for pos_ids in all_decode_position_ids ] self._cached_decode_position_ids = torch.stack(padded_decode_pos_ids, dim=0) # [batch, 3, max_decode_len] else: diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 716fa78b15dd..88dcc1e12272 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -820,8 +820,8 @@ def forward( self, inputs_embeds: torch.FloatTensor, grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: @@ -832,13 +832,12 @@ def forward( than the model's internal embedding lookup matrix. grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
- cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): - The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - position_ids (`torch.Tensor` of shape `(sequence_length, 2)`, *optional*): - Precomputed rotary position ids as `(row, column)` pairs. If not provided, will be computed based on - `image_grid_thw`. + cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). """ hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 76c22a6fe116..918739701549 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -788,8 +788,8 @@ def forward( self, inputs_embeds: torch.FloatTensor, grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: @@ -800,13 +800,12 @@ def forward( than the model's internal embedding lookup matrix. grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): - The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - position_ids (`torch.Tensor` of shape `(sequence_length, 2)`, *optional*): - Precomputed rotary position ids as `(row, column)` pairs. If not provided, will be computed based on - `image_grid_thw`. + cu_seqlens (`torch.IntTensor`, *optional*): + Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). + position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): + Precomputed (row, col) position IDs (from `get_vision_position_ids`). 
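As a usage note for the two precomputed arguments documented just above: both tensors can be built once from the grid, outside the encoder, and passed in. A minimal sketch follows; it assumes the module path and helper names used later in this series (`transformers.vision_utils`, `get_vision_cu_seqlens`, `get_vision_position_ids`), and the grid values, the `spatial_merge_size`, and the commented-out encoder call are illustrative placeholders rather than part of the patch.

    import torch
    from transformers.vision_utils import get_vision_cu_seqlens, get_vision_position_ids

    # one image whose patch grid is 1 (temporal) x 16 x 24
    image_grid_thw = torch.tensor([[1, 16, 24]])
    spatial_merge_size = 2  # illustrative; read it from the vision config in practice

    cu_seqlens = get_vision_cu_seqlens(image_grid_thw)                           # per-image/frame boundaries
    position_ids = get_vision_position_ids(image_grid_thw, spatial_merge_size)   # (total_tokens, 2) row/col ids

    # The encoder computes both itself when they are omitted; passing them keeps the
    # data-dependent ops (repeat_interleave, Python loops) out of the traced forward, e.g.:
    # vision_outputs = vision_encoder(inputs_embeds, grid_thw=image_grid_thw,
    #                                 cu_seqlens=cu_seqlens, position_ids=position_ids)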
""" hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( From a6b071f0fbcf0c5bf8e50fe42b6ba6c479574f81 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 21 Apr 2026 15:46:58 +0200 Subject: [PATCH 45/56] style --- .../models/qwen2_5_omni/modeling_qwen2_5_omni.py | 4 ++-- .../models/qwen2_5_omni/modular_qwen2_5_omni.py | 4 ++-- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 7 +++---- src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py | 7 +++---- src/transformers/models/zamba2/configuration_zamba2.py | 4 ++-- 5 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index b3989ed62c91..13e081c01556 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1308,8 +1308,6 @@ def forward( if position_ids is None: position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1323,6 +1321,8 @@ def forward( hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = self.rotary_pos_emb(position_ids) rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 70df2a44650a..7a36096c7d12 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1638,8 +1638,6 @@ def forward( if position_ids is None: position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -1653,6 +1651,8 @@ def forward( hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = self.rotary_pos_emb(position_ids) rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 0525b599dd26..a226dd043193 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -438,8 +438,6 @@ def forward( if position_ids is None: position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -448,12 +446,12 @@ def forward( grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit ) - reverse_indices = torch.argsort(window_index) - seq_len, _ = hidden_states.size() hidden_states = 
hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = self.rotary_pos_emb(position_ids) rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) @@ -473,6 +471,7 @@ def forward( **kwargs, ) + reverse_indices = torch.argsort(window_index) merged_hidden_states = self.merger(hidden_states) merged_hidden_states = merged_hidden_states[reverse_indices, :] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index ba247154df76..f5911989aacd 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -279,8 +279,6 @@ def forward( if position_ids is None: position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - if cu_seqlens is None: cu_seqlens = get_vision_cu_seqlens(grid_thw) @@ -289,12 +287,12 @@ def forward( grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit ) - reverse_indices = torch.argsort(window_index) - seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = self.rotary_pos_emb(position_ids) rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) @@ -314,6 +312,7 @@ def forward( **kwargs, ) + reverse_indices = torch.argsort(window_index) merged_hidden_states = self.merger(hidden_states) merged_hidden_states = merged_hidden_states[reverse_indices, :] diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index f64cd5968552..6d754d1ecf7a 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -29,12 +29,12 @@ class Zamba2Config(PreTrainedConfig): Number of groups for the evolution matrices of mamba 2. n_mamba_heads (`int`, *optional*, defaults to 8): Number of heads for the evolution matrices of mamba 2. + use_mamba_kernels (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use the fast mamba kernels. use_conv_bias (`bool`, *optional*, defaults to `True`): Whether or not to use bias in the convolution layer of the mixer block. chunk_size (`int`, *optional*, defaults to 256): Size of the chunks that will comprise the sequence. - use_mamba_kernels (`bool`, *optional*, defaults to `True`): - Flag indicating whether or not to use the fast mamba kernels. use_mem_eff_path (`bool`, *optional*, defaults to `False`): Whether or not to use the fused conv1d and scan in mamba2 layers. 
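The reordering in the qwen2_5_vl hunks above is pure code motion: `rotary_pos_emb` is now built after the `window_index` shuffle of `hidden_states`, and `reverse_indices = torch.argsort(window_index)` moves next to its single use before the merger. The round trip it relies on is unchanged; a self-contained check of that property, with toy sizes standing in for the real token count:

    import torch

    window_index = torch.randperm(12)                  # stand-in for the window-attention permutation
    tokens = torch.randn(12, 4)                        # stand-in for hidden states / rotary tables

    reordered = tokens[window_index]                   # applied before the attention blocks
    reverse_indices = torch.argsort(window_index)      # argsort of a permutation is its inverse
    assert torch.equal(reordered[reverse_indices], tokens)   # original order restored after the merger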
add_bias_linear (`bool`, *optional*, defaults to `False`): From e9ac058e4dc149b4d3764e850ca8d657530e50dd Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 22 Apr 2026 16:08:02 +0200 Subject: [PATCH 46/56] use chunked attention --- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 81 +++++++++--------- .../qwen2_5_omni/modular_qwen2_5_omni.py | 81 +++++++++--------- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 85 +++++++++---------- .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 25 +----- src/transformers/vision_utils.py | 8 +- 5 files changed, 126 insertions(+), 154 deletions(-) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 13e081c01556..f3559d031547 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -603,10 +603,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" seq_length, _ = hidden_states.size() @@ -618,27 +617,50 @@ def forward( query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + if is_flash_attention_requested(self.config): + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.out_proj(attn_output) @@ -663,7 +685,6 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - attention_mask: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: """ @@ -677,7 +698,6 @@ def forward( hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, - attention_mask=attention_mask, **kwargs, ) 
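For intuition on the non-flash branch introduced above: running bidirectional attention chunk by chunk over the `cu_seqlens` boundaries gives the same result as a single call with a block-diagonal mask, which is what the deleted mask construction approximated. A standalone sketch with toy shapes, independent of the model code (the shapes and chunk sizes below are arbitrary examples):

    import torch
    import torch.nn.functional as F

    heads, dim = 2, 8
    cu_seqlens = torch.tensor([0, 3, 7, 12])
    total = int(cu_seqlens[-1])
    q = torch.randn(1, heads, total, dim)
    k = torch.randn(1, heads, total, dim)
    v = torch.randn(1, heads, total, dim)

    # per-chunk attention, as in the fallback branch
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    chunked = torch.cat(
        [F.scaled_dot_product_attention(qc, kc, vc, is_causal=False)
         for qc, kc, vc in zip(q.split(lengths, dim=2), k.split(lengths, dim=2), v.split(lengths, dim=2))],
        dim=2,
    )

    # identical to one call with a block-diagonal (per-sequence) boolean mask
    block_ids = torch.searchsorted(cu_seqlens[1:], torch.arange(total), right=True)
    mask = block_ids[:, None] == block_ids[None, :]
    full = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
    torch.testing.assert_close(chunked, full)

The equivalence holds because masked-out logits contribute zero softmax weight, so attention restricted to a chunk's keys equals block-masked attention over all keys.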
hidden_states = residual + hidden_states @@ -900,29 +920,10 @@ def forward( if cu_seqlens is None: cu_seqlens = get_audio_cu_seqlens(chunk_lengths) - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention mask only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if is_flash_attention_requested(self.config): - attention_mask = None - else: - seq_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) - block_ids = torch.searchsorted(cu_seqlens[1:], seq_idx, right=True) - same_block = block_ids.unsqueeze(0) == block_ids.unsqueeze(1) - attention_mask = torch.full( - (hidden_states.shape[0], hidden_states.shape[0]), - torch.finfo(hidden_states.dtype).min, - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - attention_mask = attention_mask.masked_fill(same_block, 0.0).unsqueeze(0).unsqueeze(0) - for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, cu_seqlens=cu_seqlens, - attention_mask=attention_mask, **kwargs, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 7a36096c7d12..7aecdce960f1 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1175,10 +1175,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" seq_length, _ = hidden_states.size() @@ -1190,27 +1189,50 @@ def forward( query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + if is_flash_attention_requested(self.config): + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + 
scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.out_proj(attn_output) @@ -1227,7 +1249,6 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - attention_mask: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -1235,7 +1256,6 @@ def forward( hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, - attention_mask=attention_mask, **kwargs, ) hidden_states = residual + hidden_states @@ -1382,29 +1402,10 @@ def forward( if cu_seqlens is None: cu_seqlens = get_audio_cu_seqlens(chunk_lengths) - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention mask only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if is_flash_attention_requested(self.config): - attention_mask = None - else: - seq_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) - block_ids = torch.searchsorted(cu_seqlens[1:], seq_idx, right=True) - same_block = block_ids.unsqueeze(0) == block_ids.unsqueeze(1) - attention_mask = torch.full( - (hidden_states.shape[0], hidden_states.shape[0]), - torch.finfo(hidden_states.dtype).min, - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - attention_mask = attention_mask.masked_fill(same_block, 0.0).unsqueeze(0).unsqueeze(0) - for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, cu_seqlens=cu_seqlens, - attention_mask=attention_mask, **kwargs, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 4c5b9e44e8ee..5c6a03753d68 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -524,10 +524,9 @@ def __init__(self, config): def forward( self, hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" seq_length, _ = hidden_states.size() @@ -539,27 +538,50 @@ def forward( query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + if is_flash_attention_requested(self.config): + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - 
cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.out_proj(attn_output) @@ -584,7 +606,6 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - attention_mask: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: """ @@ -598,7 +619,6 @@ def forward( hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, - attention_mask=attention_mask, **kwargs, ) hidden_states = residual + hidden_states @@ -804,33 +824,10 @@ def forward( padded_embed = padded_embed + positional_embedding hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention mask only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. 
Though it will not be a 100% match for FA2's `varlen` path - if is_flash_attention_requested(self.config): - attention_mask = None - else: - seq_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) - block_ids = torch.searchsorted(cu_seqlens[1:], seq_idx, right=True) - same_block = block_ids.unsqueeze(0) == block_ids.unsqueeze(1) - attention_mask = ( - torch.full( - (hidden_states.shape[0], hidden_states.shape[0]), - torch.finfo(hidden_states.dtype).min, - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - .masked_fill(same_block, 0.0) - .unsqueeze(0) - .unsqueeze(0) - ) - for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, cu_seqlens, - attention_mask=attention_mask, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 57280d6465bb..195a570a2bc2 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -46,7 +46,7 @@ from ...processing_utils import ProcessorMixin, Unpack from ...tokenization_utils_base import TextInput from ...utils import auto_docstring, can_return_tuple, logging -from ...utils.generic import TransformersKwargs, is_flash_attention_requested, merge_with_config_defaults +from ...utils.generic import TransformersKwargs, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs from ...video_utils import VideoInput, make_batched_videos from ..mimi.modeling_mimi import MimiLayerScale @@ -1033,33 +1033,10 @@ def forward( padded_embed = padded_embed + positional_embedding hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention mask only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. 
Though it will not be a 100% match for FA2's `varlen` path - if is_flash_attention_requested(self.config): - attention_mask = None - else: - seq_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device) - block_ids = torch.searchsorted(cu_seqlens[1:], seq_idx, right=True) - same_block = block_ids.unsqueeze(0) == block_ids.unsqueeze(1) - attention_mask = ( - torch.full( - (hidden_states.shape[0], hidden_states.shape[0]), - torch.finfo(hidden_states.dtype).min, - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - .masked_fill(same_block, 0.0) - .unsqueeze(0) - .unsqueeze(0) - ) - for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, cu_seqlens, - attention_mask=attention_mask, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/vision_utils.py b/src/transformers/vision_utils.py index 9a065a4025e4..18d9c6b03da4 100644 --- a/src/transformers/vision_utils.py +++ b/src/transformers/vision_utils.py @@ -22,12 +22,8 @@ from __future__ import annotations -from .utils.import_utils import is_torch_available - - -if is_torch_available(): - import torch - import torch.nn.functional as F +import torch +import torch.nn.functional as F def get_vision_cu_seqlens(grid_thw: torch.Tensor) -> torch.Tensor: From a7c22775d92cbde10d1512523f7376cffd535d1b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 27 Apr 2026 23:09:10 +0200 Subject: [PATCH 47/56] use decorator --- .../modeling_ernie4_5_vl_moe.py | 55 +++---------- .../modular_ernie4_5_vl_moe.py | 26 ++---- .../models/glm46v/modeling_glm46v.py | 50 +++--------- .../models/glm4v/modeling_glm4v.py | 56 ++++--------- .../models/glm4v/modular_glm4v.py | 14 +--- .../models/glm4v_moe/modeling_glm4v_moe.py | 57 ++++--------- .../models/glm_image/modeling_glm_image.py | 14 +--- .../models/glm_image/modular_glm_image.py | 14 +--- .../models/glm_ocr/modeling_glm_ocr.py | 56 ++++--------- .../paddleocr_vl/modeling_paddleocr_vl.py | 32 +++----- .../paddleocr_vl/modular_paddleocr_vl.py | 14 +--- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 31 +++----- .../qwen2_5_omni/modular_qwen2_5_omni.py | 26 ++---- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 73 ++++------------- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 18 +---- .../models/qwen2_vl/modeling_qwen2_vl.py | 79 +++---------------- .../models/qwen3_5/modeling_qwen3_5.py | 56 +++---------- .../models/qwen3_5/modular_qwen3_5.py | 14 +--- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 56 +++---------- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 25 ++---- .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 26 ++---- .../models/qwen3_vl/modeling_qwen3_vl.py | 68 ++++------------ .../models/qwen3_vl/modular_qwen3_vl.py | 39 ++------- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 68 ++++------------ .../video_llama_3/modeling_video_llama_3.py | 42 ++-------- .../video_llama_3/modular_video_llama_3.py | 42 ++-------- src/transformers/utils/auto_docstring.py | 28 ------- src/transformers/utils/generic.py | 39 +++++++++ 28 files changed, 266 insertions(+), 852 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index e5a981b474cf..1eea462ceeca 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -40,7 +40,12 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import 
TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import OutputRecorder, capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_ernie4_5_vl_moe import Ernie4_5_VLMoeConfig, Ernie4_5_VLMoeTextConfig, Ernie4_5_VLMoeVisionConfig @@ -1250,14 +1255,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1266,14 +1270,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_outputs = self.vision_tower( - pixel_values_videos, - video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - return_dict=True, - **kwargs, - ) + video_outputs = self.vision_tower(pixel_values_videos, video_grid_thw, **kwargs) video_embeds = self.resampler_model(video_outputs.last_hidden_state, video_grid_thw) split_sizes = ( video_grid_thw.prod(-1) @@ -1284,14 +1281,13 @@ def get_video_features( video_outputs.pooler_output = video_embeds return video_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1300,14 +1296,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - image_outputs = self.vision_tower( - pixel_values, - image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, - ) + image_outputs = self.vision_tower(pixel_values, image_grid_thw, **kwargs) image_embeds = self.resampler_model(image_outputs.last_hidden_state, image_grid_thw) split_sizes = (image_grid_thw.prod(-1) // self.vision_tower.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) @@ -1595,8 +1584,6 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1604,24 +1591,14 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
- video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). """ - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1629,16 +1606,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). """ - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @auto_docstring @can_return_tuple diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index 8664cdd7361a..d90b601749c8 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -50,7 +50,7 @@ can_return_tuple, logging, ) -from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..ernie4_5_moe.configuration_ernie4_5_moe import Ernie4_5_MoeConfig @@ -964,24 +964,16 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - video_outputs = self.vision_tower( - pixel_values_videos, - video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - return_dict=True, - **kwargs, - ) + video_outputs = self.vision_tower(pixel_values_videos, video_grid_thw, **kwargs) video_embeds = self.resampler_model(video_outputs.last_hidden_state, video_grid_thw) split_sizes = ( video_grid_thw.prod(-1) @@ -992,24 +984,16 @@ def get_video_features( video_outputs.pooler_output = video_embeds return video_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - image_outputs = self.vision_tower( - pixel_values, - image_grid_thw, - 
cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, - ) + image_outputs = self.vision_tower(pixel_values, image_grid_thw, **kwargs) image_embeds = self.resampler_model(image_outputs.last_hidden_state, image_grid_thw) split_sizes = (image_grid_thw.prod(-1) // self.vision_tower.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index 4505927a7254..3f5f2dd4cd12 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -37,6 +37,7 @@ can_return_tuple, torch_compilable_check, ) +from ...utils.generic import handle_extra_kwargs from ..auto import AutoModel from .configuration_glm46v import Glm46VConfig @@ -254,14 +255,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -278,28 +278,20 @@ def get_video_features( flattened_hw = torch.repeat_interleave(hw, t, dim=0) prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) - vision_outputs = self.visual( - pixel_values_videos, - grid_thw=flattened_video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - return_dict=True, - **kwargs, - ) + vision_outputs = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -309,13 +301,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - **kwargs, - ) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -529,13 +515,12 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -543,24 +528,15 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). """ - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -568,16 +544,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). 
""" - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 5592b22c8790..db57222e5c6d 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -41,7 +41,12 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig @@ -1084,14 +1089,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1108,28 +1112,20 @@ def get_video_features( flattened_hw = torch.repeat_interleave(hw, t, dim=0) prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) - vision_outputs = self.visual( - pixel_values_videos, - grid_thw=flattened_video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - return_dict=True, - **kwargs, - ) + vision_outputs = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1139,13 +1135,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - **kwargs, - ) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1359,13 +1349,12 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1373,24 +1362,15 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). """ - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1398,16 +1378,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). 
""" - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index a385c2784fb5..0e1aa1927c3e 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -42,7 +42,7 @@ logging, torch_compilable_check, ) -from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids @@ -793,14 +793,13 @@ def __init__(self, config): super().__init__(config) self.visual = Glm4vVisionModel._from_config(config.vision_config) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -817,14 +816,7 @@ def get_video_features( flattened_hw = torch.repeat_interleave(hw, t, dim=0) prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) - vision_outputs = self.visual( - pixel_values_videos, - grid_thw=flattened_video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - return_dict=True, - **kwargs, - ) + vision_outputs = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 47cb47375a06..3428c6d5c414 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -41,7 +41,13 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, torch_compilable_check -from ...utils.generic import can_return_tuple, is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + can_return_tuple, + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig @@ -1253,14 +1259,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = 
None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1277,28 +1282,20 @@ def get_video_features( flattened_hw = torch.repeat_interleave(hw, t, dim=0) prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) - vision_outputs = self.visual( - pixel_values_videos, - grid_thw=flattened_video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - return_dict=True, - **kwargs, - ) + vision_outputs = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1308,13 +1305,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - **kwargs, - ) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1583,13 +1574,12 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1597,24 +1587,15 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). """ - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1622,16 +1603,8 @@ def get_image_features( The tensors corresponding to the input images. 
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). """ - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @auto_docstring @can_return_tuple diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 4c30e7a138f3..56a583dca040 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -39,7 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm_image import GlmImageConfig, GlmImageTextConfig, GlmImageVisionConfig, GlmImageVQVAEConfig @@ -1175,14 +1175,13 @@ def get_rope_index( return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1192,14 +1191,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, - ) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index 253c2f04821d..c6ae68ad196c 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -32,7 +32,7 @@ from ...processing_utils import ImagesKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, merge_with_config_defaults from ...utils.import_utils import requires from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids @@ -712,14 +712,13 @@ def get_image_tokens( def get_video_features(self): raise AttributeError("Not needed for GlmImage") + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -729,14 +728,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, - ) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index d7429a6c32eb..a8f1fdb1eef3 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -41,7 +41,12 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm_ocr import GlmOcrConfig, GlmOcrTextConfig, GlmOcrVisionConfig @@ -1001,14 +1006,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1025,28 +1029,20 @@ def get_video_features( flattened_hw = torch.repeat_interleave(hw, t, dim=0) prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) - vision_outputs = self.visual( - pixel_values_videos, - grid_thw=flattened_video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - return_dict=True, - **kwargs, - ) + vision_outputs = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1056,13 +1052,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - **kwargs, - ) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1276,13 +1266,12 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1290,24 +1279,15 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). """ - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1315,16 +1295,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). 
""" - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 88dcc1e12272..d332278336c9 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -44,7 +44,12 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check, torch_int -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_paddleocr_vl import PaddleOCRTextConfig, PaddleOCRVisionConfig, PaddleOCRVLConfig @@ -1209,14 +1214,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1226,14 +1230,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) - vision_outputs = self.visual( - pixel_values=pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, - ) + vision_outputs = self.visual(pixel_values=pixel_values, grid_thw=image_grid_thw, **kwargs) image_embeds = vision_outputs.last_hidden_state image_embeds = self.projector(image_embeds, image_grid_thw) vision_outputs.pooler_output = image_embeds @@ -1391,13 +1388,12 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1405,16 +1401,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). 
""" - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 0b7938009260..487f73da4097 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -54,7 +54,7 @@ torch_compilable_check, torch_int, ) -from ...utils.generic import merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..ernie4_5.configuration_ernie4_5 import Ernie4_5Config @@ -961,14 +961,13 @@ def set_input_embeddings(self, value): def get_video_features(self): raise AttributeError("PaddleOCRVLModel does not support video.") + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -978,14 +977,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) - vision_outputs = self.visual( - pixel_values=pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, - ) + vision_outputs = self.visual(pixel_values=pixel_values, grid_thw=image_grid_thw, **kwargs) image_embeds = vision_outputs.last_hidden_state image_embeds = self.projector(image_embeds, image_grid_thw) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index f3559d031547..9e0cd73c1497 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -52,7 +52,12 @@ torch_compilable_check, ) from ...utils.deprecation import deprecate_kwarg -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.hub import cached_file from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids, get_vision_window_index @@ -1773,14 +1778,13 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1790,22 +1794,15 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. 
""" pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual( - pixel_values_videos, - grid_thw=video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - **kwargs, - ) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1815,13 +1812,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - return self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - **kwargs, - ) + return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) @can_return_tuple @auto_docstring @@ -1851,7 +1842,7 @@ def get_audio_features( ) feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) audio_outputs = self.audio_tower( - input_features, feature_lens=feature_lens, aftercnn_lens=audio_feat_lengths, return_dict=True, **kwargs + input_features, feature_lens=feature_lens, aftercnn_lens=audio_feat_lengths, **kwargs ) if audio_outputs.last_hidden_state.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 7aecdce960f1..eb6254589d7e 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -44,7 +44,7 @@ torch_compilable_check, ) from ...utils.deprecation import deprecate_kwarg -from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, is_flash_attention_requested, merge_with_config_defaults from ...utils.hub import cached_file from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids, get_vision_window_index @@ -1758,14 +1758,13 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1775,22 +1774,15 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. 
""" pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual( - pixel_values_videos, - grid_thw=video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - **kwargs, - ) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1800,13 +1792,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - return self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - **kwargs, - ) + return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) @can_return_tuple @auto_docstring @@ -1836,7 +1822,7 @@ def get_audio_features( ) feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) audio_outputs = self.audio_tower( - input_features, feature_lens=feature_lens, aftercnn_lens=audio_feat_lengths, return_dict=True, **kwargs + input_features, feature_lens=feature_lens, aftercnn_lens=audio_feat_lengths, **kwargs ) if audio_outputs.last_hidden_state.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index a226dd043193..9860e30a0e27 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -45,7 +45,12 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids, get_vision_window_index from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig @@ -1102,14 +1107,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1119,27 +1123,20 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. 
""" pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - vision_outputs = self.visual( - pixel_values_videos, - grid_thw=video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - **kwargs, - ) + vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1149,13 +1146,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - **kwargs, - ) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1268,10 +1259,6 @@ def forward( rope_deltas: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, second_per_grid_ts: torch.Tensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2_5_VLModelOutputWithPast: r""" @@ -1289,12 +1276,7 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - ).pooler_output + image_embeds = self.get_image_features(pixel_values, image_grid_thw, **kwargs).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1302,12 +1284,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - ).pooler_output + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, **kwargs).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds @@ -1392,13 +1369,12 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: 
Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1406,24 +1382,15 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). """ - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1431,16 +1398,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). """ - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index f5911989aacd..9fcbf66b784f 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -518,10 +518,6 @@ def forward( rope_deltas: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, second_per_grid_ts: torch.Tensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2_5_VLModelOutputWithPast: r""" @@ -539,12 +535,7 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - ).pooler_output + image_embeds = self.get_image_features(pixel_values, image_grid_thw, **kwargs).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -552,12 +543,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - ).pooler_output + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, **kwargs).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds 
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index a2106844e76e..69f9c03bca94 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -48,6 +48,7 @@ torch_compilable_check, ) from ...utils.generic import ( + handle_extra_kwargs, is_flash_attention_requested, maybe_autocast, merge_with_config_defaults, @@ -1077,14 +1078,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1094,27 +1094,20 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - vision_outputs = self.visual( - pixel_values_videos, - grid_thw=video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - **kwargs, - ) + vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1124,13 +1117,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - **kwargs, - ) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1243,10 +1230,6 @@ def forward( video_grid_thw: torch.LongTensor | None = None, rope_deltas: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2VLModelOutputWithPast: r""" @@ -1262,12 +1245,7 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - ).pooler_output + image_embeds = self.get_image_features(pixel_values, image_grid_thw, **kwargs).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1275,12 +1253,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - ).pooler_output + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, **kwargs).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds @@ -1332,13 +1305,12 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1346,24 +1318,15 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). 
""" - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1371,16 +1334,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). """ - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring @@ -1399,10 +1354,6 @@ def forward( video_grid_thw: torch.LongTensor | None = None, rope_deltas: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen2VLCausalLMOutputWithPast: @@ -1417,10 +1368,6 @@ def forward( The temporal, height and width of feature shape of each video in LLM. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): The rope index difference between sequence length and multimodal rope. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). 
Example: @@ -1466,10 +1413,6 @@ def forward( image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, mm_token_type_ids=mm_token_type_ids, - image_cu_seqlens=image_cu_seqlens, - image_position_ids=image_position_ids, - video_cu_seqlens=video_cu_seqlens, - video_position_ids=video_position_ids, position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index fce209f5edd7..1ab437765c99 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -46,7 +46,12 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids @@ -1413,14 +1418,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1430,22 +1434,13 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ # Same implementation as for images - return self.get_image_features( - pixel_values_videos, - video_grid_thw, - image_cu_seqlens=video_cu_seqlens, - image_position_ids=video_position_ids, - **kwargs, - ) + return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) - @can_return_tuple - @auto_docstring + @handle_extra_kwargs(modality="image") def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1455,14 +1450,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_output: BaseModelOutputWithPooling = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, - ) + vision_output: BaseModelOutputWithPooling = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) @@ -1774,8 +1762,6 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1783,24 +1769,14 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). """ - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1808,16 +1784,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). 
""" - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple def forward( diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 6873cf8b7edc..80afae422310 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -28,7 +28,7 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from ..qwen3.modeling_qwen3 import Qwen3ForCausalLM @@ -575,23 +575,15 @@ def get_video_features( # Same implementation as for images return super().get_video_features(**super_kwargs) + @handle_extra_kwargs(modality="image") def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: pixel_values = pixel_values.type(self.visual.dtype) - vision_output: BaseModelOutputWithPooling = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, - ) + vision_output: BaseModelOutputWithPooling = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 03d2b4db8ce0..a9a0ea52d0e9 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -47,7 +47,12 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import OutputRecorder, capture_outputs from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids @@ -1538,14 +1543,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = 
None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1555,22 +1559,13 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ # Same implementation as for images - return self.get_image_features( - pixel_values_videos, - video_grid_thw, - image_cu_seqlens=video_cu_seqlens, - image_position_ids=video_position_ids, - **kwargs, - ) + return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) - @can_return_tuple - @auto_docstring + @handle_extra_kwargs(modality="image") def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1580,14 +1575,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_output: BaseModelOutputWithPooling = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, - ) + vision_output: BaseModelOutputWithPooling = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) @@ -1973,8 +1961,6 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -1982,24 +1968,14 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). """ - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -2007,16 +1983,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). 
""" - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple def forward( diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 5c6a03753d68..c56fd12e7836 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -57,6 +57,7 @@ from ...utils import auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import ( TransformersKwargs, + handle_extra_kwargs, is_flash_attention_requested, maybe_autocast, merge_with_config_defaults, @@ -1931,14 +1932,13 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1948,22 +1948,15 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual( - pixel_values_videos, - grid_thw=video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - **kwargs, - ) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1973,13 +1966,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - return self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - **kwargs, - ) + return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) @can_return_tuple @auto_docstring @@ -2005,7 +1992,7 @@ def get_audio_features( audio_feature_lengths = None feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) + audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, **kwargs) return audio_outputs diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 195a570a2bc2..489504e6de12 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -46,7 +46,7 @@ from ...processing_utils import ProcessorMixin, Unpack from ...tokenization_utils_base import TextInput from ...utils import auto_docstring, can_return_tuple, logging -from ...utils.generic import TransformersKwargs, merge_with_config_defaults +from ...utils.generic import TransformersKwargs, handle_extra_kwargs, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs from ...video_utils import VideoInput, make_batched_videos from ..mimi.modeling_mimi import MimiLayerScale @@ -1180,14 +1180,13 @@ def __init__(self, config): self.num_experts_per_tok = config.text_config.num_experts_per_tok self.router_aux_loss_coef = config.text_config.router_aux_loss_coef + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1197,22 +1196,15 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual( - pixel_values_videos, - grid_thw=video_grid_thw, - cu_seqlens=video_cu_seqlens, - position_ids=video_position_ids, - **kwargs, - ) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1222,13 +1214,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - return self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - **kwargs, - ) + return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) @can_return_tuple @auto_docstring @@ -1254,7 +1240,7 @@ def get_audio_features( audio_feature_lengths = None feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) + audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, **kwargs) return audio_outputs diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index f2a4182f6a8b..bb44b2d90fcb 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -40,7 +40,12 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig @@ -1055,14 +1060,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1072,22 +1076,15 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. 
""" # Same implementation as for images - return self.get_image_features( - pixel_values_videos, - video_grid_thw, - image_cu_seqlens=video_cu_seqlens, - image_position_ids=video_position_ids, - **kwargs, - ) + return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1098,12 +1095,7 @@ def get_image_features( """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, + pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1216,10 +1208,6 @@ def forward( image_grid_thw: torch.LongTensor | None = None, video_grid_thw: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen3VLModelOutputWithPast: r""" @@ -1239,11 +1227,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - return_dict=True, + pixel_values, image_grid_thw, **kwargs ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -1255,11 +1239,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - return_dict=True, + pixel_values_videos, video_grid_thw, **kwargs ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features @@ -1374,8 +1354,6 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1383,24 +1361,14 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). 
""" - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1408,16 +1376,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). """ - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple def forward( diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index d41b3b800ac6..ae2d0184dc3d 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -37,7 +37,7 @@ from ...processing_utils import ProcessingKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, can_return_tuple, logging -from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids @@ -705,14 +705,13 @@ def get_rope_index( return super().get_rope_index(video_grid_thw=video_grid_thw, **super_kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -723,12 +722,7 @@ def get_image_features( """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, + pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -737,14 +731,13 @@ def get_image_features( return vision_output + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -754,13 +747,7 @@ def get_video_features( The temporal, height 
and width of feature shape of each video in LLM. """ # Same implementation as for images - return self.get_image_features( - pixel_values_videos, - video_grid_thw, - image_cu_seqlens=video_cu_seqlens, - image_position_ids=video_position_ids, - **kwargs, - ) + return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) @auto_docstring @can_return_tuple @@ -776,10 +763,6 @@ def forward( image_grid_thw: torch.LongTensor | None = None, video_grid_thw: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen3VLModelOutputWithPast: r""" @@ -799,11 +782,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - return_dict=True, + pixel_values, image_grid_thw, **kwargs ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -815,11 +794,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - return_dict=True, + pixel_values_videos, video_grid_thw, **kwargs ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 43df218ea2bc..590bae746b72 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -46,7 +46,12 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import OutputRecorder, capture_outputs from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig @@ -1185,14 +1190,13 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1202,22 +1206,15 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. 
""" # Same implementation as for images - return self.get_image_features( - pixel_values_videos, - video_grid_thw, - image_cu_seqlens=video_cu_seqlens, - image_position_ids=video_position_ids, - **kwargs, - ) + return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1228,12 +1225,7 @@ def get_image_features( """ pixel_values = pixel_values.type(self.visual.dtype) vision_output: BaseModelOutputWithDeepstackFeatures = self.visual( - pixel_values, - grid_thw=image_grid_thw, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, + pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() @@ -1346,10 +1338,6 @@ def forward( image_grid_thw: torch.LongTensor | None = None, video_grid_thw: torch.LongTensor | None = None, mm_token_type_ids: torch.IntTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | Qwen3VLMoeModelOutputWithPast: r""" @@ -1369,11 +1357,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - return_dict=True, + pixel_values, image_grid_thw, **kwargs ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -1385,11 +1369,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - return_dict=True, + pixel_values_videos, video_grid_thw, **kwargs ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features @@ -1557,8 +1537,6 @@ def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1566,24 +1544,14 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. - video_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for videos (from `get_vision_cu_seqlens`). 
""" - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_cu_seqlens, - video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: r""" @@ -1591,16 +1559,8 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - image_cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths for images (from `get_vision_cu_seqlens`). """ - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_cu_seqlens, - image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple def forward( diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index fdad9f96a9bc..b12d5e4173c8 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -34,7 +34,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, is_flash_attention_requested, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..auto.modeling_auto import AutoModel @@ -522,6 +522,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -529,8 +530,6 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor, video_merge_sizes: torch.LongTensor, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -545,11 +544,10 @@ def get_video_features( pixel_values=pixel_values_videos, image_grid_thw=video_grid_thw, image_merge_sizes=video_merge_sizes, - image_cu_seqlens=video_cu_seqlens, - image_position_ids=video_position_ids, **kwargs, ) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -557,8 +555,6 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor, image_merge_sizes: torch.LongTensor, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -570,13 +566,7 @@ def get_image_features( The spatial downsampling ratio of each image feature. 
""" vision_outputs = self.vision_model( - pixel_values=pixel_values, - grid_thw=image_grid_thw, - merge_sizes=image_merge_sizes, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, + pixel_values=pixel_values, grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, **kwargs ) last_hidden_state = vision_outputs.last_hidden_state image_embeds = self.projector(last_hidden_state) @@ -758,6 +748,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -765,8 +756,6 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_merge_sizes: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -777,15 +766,9 @@ def get_video_features( video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`, *optional*): The spatial downsampling ratio of each video feature. """ - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_merge_sizes, - video_cu_seqlens=video_cu_seqlens, - video_position_ids=video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, video_merge_sizes, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -793,8 +776,6 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_merge_sizes: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -805,14 +786,7 @@ def get_image_features( image_merge_sizes (`torch.Tensor` of shape `(num_images,)`, *optional*): The spatial downsampling ratio of each image feature. 
""" - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_merge_sizes, - image_cu_seqlens=image_cu_seqlens, - image_position_ids=image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, image_merge_sizes, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index adb7180a32ce..b7db0f475662 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -45,7 +45,7 @@ can_return_tuple, logging, ) -from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, is_flash_attention_requested, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...video_utils import ( VideoInput, @@ -483,6 +483,7 @@ def get_vision_position_ids(self): def compute_3d_position_ids(self): raise AttributeError("Not needed for VideoLLaMA3") + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -490,8 +491,6 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor, video_merge_sizes: torch.LongTensor, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -506,11 +505,10 @@ def get_video_features( pixel_values=pixel_values_videos, image_grid_thw=video_grid_thw, image_merge_sizes=video_merge_sizes, - image_cu_seqlens=video_cu_seqlens, - image_position_ids=video_position_ids, **kwargs, ) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -518,8 +516,6 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor, image_merge_sizes: torch.LongTensor, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -531,13 +527,7 @@ def get_image_features( The spatial downsampling ratio of each image feature. 
""" vision_outputs = self.vision_model( - pixel_values=pixel_values, - grid_thw=image_grid_thw, - merge_sizes=image_merge_sizes, - cu_seqlens=image_cu_seqlens, - position_ids=image_position_ids, - return_dict=True, - **kwargs, + pixel_values=pixel_values, grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, **kwargs ) last_hidden_state = vision_outputs.last_hidden_state image_embeds = self.projector(last_hidden_state) @@ -667,6 +657,7 @@ class VideoLlama3ForConditionalGeneration(Qwen2VLForConditionalGeneration): def __init__(self, config: VideoLlama3Config): super().__init__(config) # just to add type hint on config + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -674,8 +665,6 @@ def get_image_features( pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, image_merge_sizes: torch.LongTensor | None = None, - image_cu_seqlens: torch.Tensor | None = None, - image_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -686,15 +675,9 @@ def get_image_features( image_merge_sizes (`torch.Tensor` of shape `(num_images,)`, *optional*): The spatial downsampling ratio of each image feature. """ - return self.model.get_image_features( - pixel_values, - image_grid_thw, - image_merge_sizes, - image_cu_seqlens=image_cu_seqlens, - image_position_ids=image_position_ids, - **kwargs, - ) + return self.model.get_image_features(pixel_values, image_grid_thw, image_merge_sizes, **kwargs) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -702,8 +685,6 @@ def get_video_features( pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, video_merge_sizes: torch.LongTensor | None = None, - video_cu_seqlens: torch.Tensor | None = None, - video_position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -714,14 +695,7 @@ def get_video_features( video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`, *optional*): The spatial downsampling ratio of each video feature. """ - return self.model.get_video_features( - pixel_values_videos, - video_grid_thw, - video_merge_sizes, - video_cu_seqlens=video_cu_seqlens, - video_position_ids=video_position_ids, - **kwargs, - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, video_merge_sizes, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index 619de6f0122f..bd04f3fb901e 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -2121,34 +2121,6 @@ class ModelArgs: "shape": "of shape `(batch_size, num_frames, num_channels, frame_size, frame_size)`", } - image_cu_seqlens = { - "description": """ - Precomputed cumulative sequence lengths for image patches, used for packed variable-length attention. - """, - "shape": "of shape `(num_image_patches + 1,)`", - } - - video_cu_seqlens = { - "description": """ - Precomputed cumulative sequence lengths for video patches, used for packed variable-length attention. - """, - "shape": "of shape `(num_video_patches + 1,)`", - } - - image_position_ids = { - "description": """ - Precomputed (row, col) position IDs for image rotary embeddings. 
-            """,
-        "shape": "of shape `(num_image_tokens, 2)`",
-    }
-
-    video_position_ids = {
-        "description": """
-            Precomputed (row, col) position IDs for video rotary embeddings.
-            """,
-        "shape": "of shape `(num_video_tokens, 2)`",
-    }
-
     vision_feature_layer = {
         "description": """
             The index of the layer to select the vision feature. If multiple indices are provided,
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index 56bf5a47d0a6..a95b287200be 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -892,6 +892,45 @@ def wrapper(self, *args, **kwargs):
         return wrapper
 
 
+_KNOWN_MODALITIES = ("image", "video", "audio")
+
+
+def handle_extra_kwargs(modality: str):
+    """
+    Decorator for ``get_*_features`` methods that:
+    - strips the modality prefix from incoming kwargs whose full name is not already a named
+      parameter (e.g. ``image_cu_seqlens`` → ``cu_seqlens``, forwarded via ``**kwargs``);
+    - drops kwargs prefixed with another known modality (e.g. ``video_*`` passed to an
+      image method), so an outer ``forward()`` can blindly forward ``**kwargs`` to each
+      modality method without leaking the wrong tensors into the wrong encoder;
+    - leaves everything else untouched (including kwargs that match a named parameter).
+
+    Used so multimodal models can accept arbitrary precomputed tensors (``image_cu_seqlens``,
+    ``video_position_ids``, …) without enumerating each one in every signature.
+    """
+    prefix = f"{modality}_"
+    other_prefixes = tuple(f"{m}_" for m in _KNOWN_MODALITIES if m != modality)
+
+    def decorator(func):
+        existing_params = set(inspect.signature(func).parameters)
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            translated = {}
+            for k, v in kwargs.items():
+                if k.startswith(other_prefixes):
+                    continue
+                if k.startswith(prefix) and k not in existing_params:
+                    translated[k.removeprefix(prefix)] = v
+                else:
+                    translated[k] = v
+            return func(*args, **translated)
+
+        return wrapper
+
+    return decorator
+
+
 def merge_with_config_defaults(func):
     """
     Decorator using config field (if they exist) as default value for some args and kwargs.
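A minimal usage sketch of the `handle_extra_kwargs` decorator added in the hunk above (illustrative only, not part of the patch): the toy function and string-valued "tensors" are assumptions, and the decorator itself is taken to be in scope.

    @handle_extra_kwargs(modality="image")
    def get_image_features(pixel_values, image_grid_thw=None, **kwargs):
        # kwargs contains only what survived translation and was forwarded via **kwargs
        return kwargs

    extra = get_image_features(
        pixel_values="px",
        image_grid_thw="thw",        # matches a named parameter -> kept as-is
        image_cu_seqlens="cu",       # prefixed, no matching parameter -> renamed to `cu_seqlens`
        video_position_ids="vp",     # other modality -> dropped before the call
    )
    assert extra == {"cu_seqlens": "cu"}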
Precedence is always From 6d33d4a1713b855cca2d422540665040dd64dacf Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 27 Apr 2026 23:37:36 +0200 Subject: [PATCH 48/56] pass kwargs and return_dict --- .../ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py | 12 ++++++++---- .../ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py | 12 ++++++++---- .../models/glm46v/modeling_glm46v.py | 14 ++++++++++---- src/transformers/models/glm4v/modeling_glm4v.py | 14 ++++++++++---- src/transformers/models/glm4v/modular_glm4v.py | 12 +++++++++--- .../models/glm4v_moe/modeling_glm4v_moe.py | 14 ++++++++++---- .../models/glm_image/modeling_glm_image.py | 4 ++-- .../models/glm_image/modular_glm_image.py | 4 ++-- .../models/glm_ocr/modeling_glm_ocr.py | 14 ++++++++++---- .../paddleocr_vl/modeling_paddleocr_vl.py | 6 ++++-- .../models/paddleocr_vl/modular_paddleocr_vl.py | 6 ++++-- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 17 +++++++++-------- .../models/qwen2_5_omni/modular_qwen2_5_omni.py | 17 +++++++++-------- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 4 ++-- .../models/qwen2_vl/modeling_qwen2_vl.py | 4 ++-- .../models/qwen3_5/modeling_qwen3_5.py | 8 +++++--- .../models/qwen3_5/modular_qwen3_5.py | 8 +++++--- .../models/qwen3_5_moe/modeling_qwen3_5_moe.py | 8 +++++--- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 15 ++++++--------- .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 15 ++++++--------- .../models/qwen3_vl/modeling_qwen3_vl.py | 4 ++-- .../models/qwen3_vl/modular_qwen3_vl.py | 4 ++-- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 4 ++-- .../video_llama_3/modeling_video_llama_3.py | 10 +++++++--- .../video_llama_3/modular_video_llama_3.py | 10 +++++++--- 25 files changed, 146 insertions(+), 94 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index 1eea462ceeca..bcaf2ac9cbf0 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -1270,7 +1270,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - video_outputs = self.vision_tower(pixel_values_videos, video_grid_thw, **kwargs) + video_outputs = self.vision_tower(pixel_values_videos, video_grid_thw, return_dict=True, **kwargs) video_embeds = self.resampler_model(video_outputs.last_hidden_state, video_grid_thw) split_sizes = ( video_grid_thw.prod(-1) @@ -1296,7 +1296,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - image_outputs = self.vision_tower(pixel_values, image_grid_thw, **kwargs) + image_outputs = self.vision_tower(pixel_values, image_grid_thw, return_dict=True, **kwargs) image_embeds = self.resampler_model(image_outputs.last_hidden_state, image_grid_thw) split_sizes = (image_grid_thw.prod(-1) // self.vision_tower.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) @@ -1428,7 +1428,9 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1436,7 +1438,9 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index d90b601749c8..d370c404fa30 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -973,7 +973,7 @@ def get_video_features( video_grid_thw: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - video_outputs = self.vision_tower(pixel_values_videos, video_grid_thw, **kwargs) + video_outputs = self.vision_tower(pixel_values_videos, video_grid_thw, return_dict=True, **kwargs) video_embeds = self.resampler_model(video_outputs.last_hidden_state, video_grid_thw) split_sizes = ( video_grid_thw.prod(-1) @@ -993,7 +993,7 @@ def get_image_features( image_grid_thw: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - image_outputs = self.vision_tower(pixel_values, image_grid_thw, **kwargs) + image_outputs = self.vision_tower(pixel_values, image_grid_thw, return_dict=True, **kwargs) image_embeds = self.resampler_model(image_outputs.last_hidden_state, image_grid_thw) split_sizes = (image_grid_thw.prod(-1) // self.vision_tower.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) @@ -1035,7 +1035,9 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1043,7 +1045,9 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, 
image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index 3f5f2dd4cd12..90aa1bd8a568 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -278,7 +278,9 @@ def get_video_features( flattened_hw = torch.repeat_interleave(hw, t, dim=0) prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) - vision_outputs = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw, **kwargs) + vision_outputs = self.visual( + pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs + ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds @@ -301,7 +303,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -431,13 +433,17 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index db57222e5c6d..8461bb4b7600 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1112,7 +1112,9 @@ def get_video_features( flattened_hw = torch.repeat_interleave(hw, t, dim=0) prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = 
torch.cat([prefix_ones, flattened_hw], dim=1) - vision_outputs = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw, **kwargs) + vision_outputs = self.visual( + pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs + ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds @@ -1135,7 +1137,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1265,13 +1267,17 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 0e1aa1927c3e..cdaebf121d78 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -816,7 +816,9 @@ def get_video_features( flattened_hw = torch.repeat_interleave(hw, t, dim=0) prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) - vision_outputs = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw, **kwargs) + vision_outputs = self.visual( + pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs + ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds @@ -934,13 +936,17 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask(input_ids, 
inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 3428c6d5c414..65df83f24463 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -1282,7 +1282,9 @@ def get_video_features( flattened_hw = torch.repeat_interleave(hw, t, dim=0) prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) - vision_outputs = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw, **kwargs) + vision_outputs = self.visual( + pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs + ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds @@ -1305,7 +1307,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1435,13 +1437,17 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 56a583dca040..26ec9892bbb0 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ 
b/src/transformers/models/glm_image/modeling_glm_image.py @@ -1191,7 +1191,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1333,7 +1333,7 @@ def forward( # Fallback for batch_size=1: all but last grid are source images source_grids = image_grid_thw[:-1] - image_features = self.get_image_features(pixel_values, source_grids, return_dict=True) + image_features = self.get_image_features(pixel_values, source_grids, return_dict=True, **kwargs) image_embeds = torch.cat(image_features.pooler_output, dim=0) image_ids = self.get_image_tokens(image_embeds, source_grids) image_ids = image_ids.view(-1).to(input_ids.device) diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index c6ae68ad196c..33d372fdd249 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -728,7 +728,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes) vision_outputs.pooler_output = image_embeds @@ -868,7 +868,7 @@ def forward( # Fallback for batch_size=1: all but last grid are source images source_grids = image_grid_thw[:-1] - image_features = self.get_image_features(pixel_values, source_grids, return_dict=True) + image_features = self.get_image_features(pixel_values, source_grids, return_dict=True, **kwargs) image_embeds = torch.cat(image_features.pooler_output, dim=0) image_ids = self.get_image_tokens(image_embeds, source_grids) image_ids = image_ids.view(-1).to(input_ids.device) diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index a8f1fdb1eef3..c68954ba59d7 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -1029,7 +1029,9 @@ def get_video_features( flattened_hw = torch.repeat_interleave(hw, t, dim=0) prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) - vision_outputs = self.visual(pixel_values_videos, grid_thw=flattened_video_grid_thw, **kwargs) + vision_outputs = self.visual( + pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs + ) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds @@ -1052,7 +1054,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1182,13 +1184,17 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index d332278336c9..b148d97304bf 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -1230,7 +1230,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) - vision_outputs = self.visual(pixel_values=pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values=pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) image_embeds = vision_outputs.last_hidden_state image_embeds = self.projector(image_embeds, image_grid_thw) vision_outputs.pooler_output = image_embeds @@ -1335,7 +1335,9 @@ def forward( inputs_embeds = self.language_model.embed_tokens(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 487f73da4097..43458db2d92d 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -977,7 +977,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) - vision_outputs = self.visual(pixel_values=pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values=pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) image_embeds = vision_outputs.last_hidden_state image_embeds = self.projector(image_embeds, image_grid_thw) vision_outputs.pooler_output = image_embeds @@ -1033,7 +1033,9 @@ def forward( inputs_embeds = self.language_model.embed_tokens(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 9e0cd73c1497..be4742830e51 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1794,7 +1794,7 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) @handle_extra_kwargs(modality="image") @can_return_tuple @@ -1812,7 +1812,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) @can_return_tuple @auto_docstring @@ -1985,17 +1985,16 @@ def forward( # 2. 
Merge text , audios , image and video if input_features is not None: audio_features = self.get_audio_features( - input_features, - feature_attention_mask=feature_attention_mask, - audio_feature_lengths=audio_feature_lengths, - return_dict=True, + input_features, feature_attention_mask, audio_feature_lengths, return_dict=True, **kwargs ).last_hidden_state audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -2003,7 +2002,9 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index eb6254589d7e..be82c77a87cc 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1774,7 +1774,7 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) @handle_extra_kwargs(modality="image") @can_return_tuple @@ -1792,7 +1792,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) @can_return_tuple @auto_docstring @@ -1965,17 +1965,16 @@ def forward( # 2. 
Merge text , audios , image and video if input_features is not None: audio_features = self.get_audio_features( - input_features, - feature_attention_mask=feature_attention_mask, - audio_feature_lengths=audio_feature_lengths, - return_dict=True, + input_features, feature_attention_mask, audio_feature_lengths, return_dict=True, **kwargs ).last_hidden_state audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1983,7 +1982,9 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 9860e30a0e27..c8a76f8982fc 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1123,7 +1123,7 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds @@ -1146,7 +1146,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 69f9c03bca94..897d47c874e8 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1094,7 +1094,7 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. 
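# Illustrative sketch, not part of the patch: the per-video split applied in the hunks above.
# Grid values are made up; spatial_merge_size=2 is assumed.
import torch

video_grid_thw = torch.tensor([[2, 4, 4], [1, 8, 6]])  # (t, h, w) for two videos
spatial_merge_size = 2
# each video contributes t*h*w patches, merged in 2x2 blocks -> divide by merge_size**2
split_sizes = (video_grid_thw.prod(-1) // spatial_merge_size**2).tolist()
assert split_sizes == [8, 12]  # 2*4*4 // 4 = 8, 1*8*6 // 4 = 12
# torch.split(vision_outputs.pooler_output, split_sizes) then yields one tensor per video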
""" pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds @@ -1117,7 +1117,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index dfd1f802af24..2fcf9890b5b3 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1459,7 +1459,9 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_output: BaseModelOutputWithPooling = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_output: BaseModelOutputWithPooling = self.visual( + pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) @@ -1587,7 +1589,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithPooling = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) @@ -1598,7 +1600,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithPooling = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 9d3b0052efef..6d411f6dfe0d 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -592,7 +592,9 @@ def get_image_features( **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: pixel_values = pixel_values.type(self.visual.dtype) - vision_output: BaseModelOutputWithPooling = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_output: BaseModelOutputWithPooling = self.visual( + pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) @@ -630,7 +632,7 @@ def 
forward( if pixel_values is not None: image_outputs: BaseModelOutputWithPooling = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) @@ -641,7 +643,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithPooling = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 215d64254fec..c36833a477f6 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1584,7 +1584,9 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_output: BaseModelOutputWithPooling = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_output: BaseModelOutputWithPooling = self.visual( + pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs + ) image_embeds = vision_output.pooler_output split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) @@ -1712,7 +1714,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithPooling = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) @@ -1723,7 +1725,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithPooling = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index c56fd12e7836..863c92814372 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1948,7 +1948,7 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) @handle_extra_kwargs(modality="image") @can_return_tuple @@ -1966,7 +1966,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
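# Illustrative sketch, not part of the patch: with return_dict=True forced and **kwargs
# forwarded as above, precomputed tensors handed in alongside the pixel values now travel
# through get_image_features down to the vision tower instead of being dropped.
# `model` and `precomputed_cu_seqlens` are placeholders.
image_outputs = model.get_image_features(
    pixel_values,
    image_grid_thw,
    return_dict=True,
    cu_seqlens=precomputed_cu_seqlens,  # picked up by the vision encoder forward
)
image_embeds = image_outputs.pooler_output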
""" pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) @can_return_tuple @auto_docstring @@ -1992,7 +1992,7 @@ def get_audio_features( audio_feature_lengths = None feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, **kwargs) + audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) return audio_outputs @@ -2139,10 +2139,7 @@ def forward( # 2. Merge text , audios , image and video if input_features is not None: audio_features = self.get_audio_features( - input_features, - feature_attention_mask=feature_attention_mask, - audio_feature_lengths=audio_feature_lengths, - return_dict=True, + input_features, feature_attention_mask, audio_feature_lengths, return_dict=True, **kwargs ).last_hidden_state audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) @@ -2150,7 +2147,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output image_embeds_multiscale = image_outputs.deepstack_features @@ -2162,7 +2159,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output video_embeds_multiscale = video_outputs.deepstack_features diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 310506fe37ca..2ffe2dc3fb83 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1203,7 +1203,7 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) @handle_extra_kwargs(modality="image") @can_return_tuple @@ -1221,7 +1221,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) @can_return_tuple @auto_docstring @@ -1247,7 +1247,7 @@ def get_audio_features( audio_feature_lengths = None feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, **kwargs) + audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) return audio_outputs @@ -1289,10 +1289,7 @@ def forward( # 2. 
Merge text , audios , image and video if input_features is not None: audio_features = self.get_audio_features( - input_features, - feature_attention_mask=feature_attention_mask, - audio_feature_lengths=audio_feature_lengths, - return_dict=True, + input_features, feature_attention_mask, audio_feature_lengths, return_dict=True, **kwargs ).last_hidden_state audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) @@ -1300,7 +1297,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output image_embeds_multiscale = image_outputs.deepstack_features @@ -1312,7 +1309,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output video_embeds_multiscale = video_outputs.deepstack_features diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index bb44b2d90fcb..517598021632 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -1227,7 +1227,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, **kwargs + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -1239,7 +1239,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, **kwargs + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index ae2d0184dc3d..d056351c11ed 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -782,7 +782,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, **kwargs + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -794,7 +794,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, **kwargs + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 590bae746b72..f7dc594a395e 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ 
-1357,7 +1357,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, **kwargs + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -1369,7 +1369,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, **kwargs + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index b12d5e4173c8..be8bea0dc2ad 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -566,7 +566,11 @@ def get_image_features( The spatial downsampling ratio of each image feature. """ vision_outputs = self.vision_model( - pixel_values=pixel_values, grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, **kwargs + pixel_values=pixel_values, + grid_thw=image_grid_thw, + merge_sizes=image_merge_sizes, + return_dict=True, + **kwargs, ) last_hidden_state = vision_outputs.last_hidden_state image_embeds = self.projector(last_hidden_state) @@ -655,7 +659,7 @@ def forward( image_embeds = None if pixel_values is not None: image_embeds = self.get_image_features( - pixel_values, image_grid_thw, image_merge_sizes, return_dict=True + pixel_values, image_grid_thw, image_merge_sizes, return_dict=True, **kwargs ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( @@ -666,7 +670,7 @@ def forward( video_embeds = None if pixel_values_videos is not None: video_embeds = self.get_video_features( - pixel_values_videos, video_grid_thw, video_merge_sizes, return_dict=True + pixel_values_videos, video_grid_thw, video_merge_sizes, return_dict=True, **kwargs ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) if video_compression_mask is not None: diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index b7db0f475662..b90062fbd362 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -527,7 +527,11 @@ def get_image_features( The spatial downsampling ratio of each image feature. 
""" vision_outputs = self.vision_model( - pixel_values=pixel_values, grid_thw=image_grid_thw, merge_sizes=image_merge_sizes, **kwargs + pixel_values=pixel_values, + grid_thw=image_grid_thw, + merge_sizes=image_merge_sizes, + return_dict=True, + **kwargs, ) last_hidden_state = vision_outputs.last_hidden_state image_embeds = self.projector(last_hidden_state) @@ -575,7 +579,7 @@ def forward( image_embeds = None if pixel_values is not None: image_embeds = self.get_image_features( - pixel_values, image_grid_thw, image_merge_sizes, return_dict=True + pixel_values, image_grid_thw, image_merge_sizes, return_dict=True, **kwargs ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( @@ -586,7 +590,7 @@ def forward( video_embeds = None if pixel_values_videos is not None: video_embeds = self.get_video_features( - pixel_values_videos, video_grid_thw, video_merge_sizes, return_dict=True + pixel_values_videos, video_grid_thw, video_merge_sizes, return_dict=True, **kwargs ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) if video_compression_mask is not None: From fe3bcc4080d122bf8d5764e8afee24af414d05a3 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 28 Apr 2026 11:11:30 +0200 Subject: [PATCH 49/56] fix missing --- src/transformers/models/qwen3_5/modeling_qwen3_5.py | 2 ++ src/transformers/models/qwen3_5/modular_qwen3_5.py | 2 ++ src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 2fcf9890b5b3..baed2fe9b0fc 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1446,6 +1446,8 @@ def get_video_features( return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) @handle_extra_kwargs(modality="image") + @can_return_tuple + @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 6d411f6dfe0d..ab7cd4efd664 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -585,6 +585,8 @@ def get_video_features( return super().get_video_features(**super_kwargs) @handle_extra_kwargs(modality="image") + @can_return_tuple + @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index c36833a477f6..e84062bd43c7 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1571,6 +1571,8 @@ def get_video_features( return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) @handle_extra_kwargs(modality="image") + @can_return_tuple + @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, From 51f7e206792db12ac123cd95927f8688803e654b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 28 Apr 2026 12:00:00 +0200 Subject: [PATCH 50/56] keep in and get from kwargs --- .../modeling_ernie4_5_vl_moe.py | 16 +----- .../modular_ernie4_5_vl_moe.py | 12 +--- .../models/glm4v/modeling_glm4v.py | 16 +----- 
.../models/glm4v/modular_glm4v.py | 16 +----- .../models/glm4v_moe/modeling_glm4v_moe.py | 16 +----- .../models/glm_image/modeling_glm_image.py | 16 +----- .../models/glm_image/modular_glm_image.py | 16 +----- .../paddleocr_vl/modeling_paddleocr_vl.py | 38 ++----------- .../paddleocr_vl/modular_paddleocr_vl.py | 38 ++----------- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 55 ++++--------------- .../qwen2_5_omni/modular_qwen2_5_omni.py | 55 ++++--------------- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 26 ++------- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 26 ++------- .../models/qwen2_vl/modeling_qwen2_vl.py | 17 +----- .../models/qwen3_5/modeling_qwen3_5.py | 27 ++------- .../models/qwen3_5/modular_qwen3_5.py | 27 ++------- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 27 ++------- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 51 ++++------------- .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 24 ++------ .../models/qwen3_vl/modeling_qwen3_vl.py | 27 ++------- .../models/qwen3_vl/modular_qwen3_vl.py | 27 ++------- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 27 ++------- .../video_llama_3/modeling_video_llama_3.py | 15 +---- .../video_llama_3/modular_video_llama_3.py | 15 +---- 24 files changed, 128 insertions(+), 502 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index bcaf2ac9cbf0..fce3ce797927 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -914,30 +914,20 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. - cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). 
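# Illustrative sketch, not part of the patch: what a precomputed cu_seqlens looks like for
# two single-frame images of 2x2 and 3x2 patches (values worked out by hand from the
# definition of get_vision_cu_seqlens).
import torch

grid_thw = torch.tensor([[1, 2, 2], [1, 3, 2]])
# per-entry patch counts: 1*2*2 = 4 and 1*3*2 = 6; cumulative boundaries padded with a leading 0
# -> get_vision_cu_seqlens(grid_thw) == tensor([0, 4, 10], dtype=torch.int32)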
""" - hidden_states = self.patch_embed(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - for block in self.blocks: hidden_states = block( hidden_states, diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index d370c404fa30..3b2d7721a3d9 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -702,22 +702,16 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: - hidden_states = self.patch_embed(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - for block in self.blocks: hidden_states = block( hidden_states, diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 8461bb4b7600..3c4266995c15 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -748,8 +748,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -757,23 +755,15 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. 
""" + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index cdaebf121d78..9dc601aecc94 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -627,8 +627,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -636,23 +634,15 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. """ + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 65df83f24463..b60b8870945f 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -813,8 +813,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -822,23 +820,15 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. 
""" + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - rotary_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 26ec9892bbb0..64ef7d636d5d 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -605,8 +605,6 @@ def forward( self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -614,22 +612,14 @@ def forward( Packed pixel values. grid_thw (`torch.Tensor` of shape `(num_images, 3)`): The temporal, height and width of feature shape of each image. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. """ - hidden_states = self.patch_embed(pixel_values) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + hidden_states = self.patch_embed(pixel_values) seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index 33d372fdd249..ba7eb08aa08e 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -445,8 +445,6 @@ def forward( self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -454,22 +452,14 @@ def forward( Packed pixel values. grid_thw (`torch.Tensor` of shape `(num_images, 3)`): The temporal, height and width of feature shape of each image. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. 
""" - hidden_states = self.patch_embed(pixel_values) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + hidden_states = self.patch_embed(pixel_values) seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index b148d97304bf..b5e6f8e2a228 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -826,8 +826,6 @@ def forward( inputs_embeds: torch.FloatTensor, grid_thw: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -839,30 +837,22 @@ def forward( The temporal, height and width of feature shape of each image in LLM. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). """ + # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), + # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, 1) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, ) - - if position_ids is None: - # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), - # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). - position_ids = get_vision_position_ids(grid_thw, 1) - rotary_embeddings = self.rotary_pos_emb(position_ids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, @@ -902,9 +892,7 @@ def forward( self, pixel_values: torch.FloatTensor, grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ @@ -915,16 +903,12 @@ def forward( The attention_mask used in forward function shape [batch_size X sequence_length] if not None. grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): - Precomputed cumulative sequence lengths. 
""" hidden_states = self.embeddings(pixel_values, grid_thw=grid_thw) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, grid_thw=grid_thw, - cu_seqlens=cu_seqlens, attention_mask=attention_mask, - position_ids=position_ids, **kwargs, ) @@ -954,8 +938,6 @@ def forward( self, pixel_values: torch.FloatTensor, grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -964,16 +946,8 @@ def forward( The tensors corresponding to the input images. grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. """ - return self.vision_model( - pixel_values=pixel_values, - grid_thw=grid_thw, - cu_seqlens=cu_seqlens, - position_ids=position_ids, - **kwargs, - ) + return self.vision_model(pixel_values=pixel_values, grid_thw=grid_thw, **kwargs) @dataclass diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 43458db2d92d..a5336520c341 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -782,8 +782,6 @@ def forward( inputs_embeds: torch.FloatTensor, grid_thw: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -795,30 +793,22 @@ def forward( The temporal, height and width of feature shape of each image in LLM. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). """ + # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), + # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, 1) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, ) - - if position_ids is None: - # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), - # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). 
- position_ids = get_vision_position_ids(grid_thw, 1) - rotary_embeddings = self.rotary_pos_emb(position_ids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - for encoder_layer in self.layers: hidden_states = encoder_layer( hidden_states, @@ -858,9 +848,7 @@ def forward( self, pixel_values: torch.FloatTensor, grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ @@ -871,16 +859,12 @@ def forward( The attention_mask used in forward function shape [batch_size X sequence_length] if not None. grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`, *optional*): - Precomputed cumulative sequence lengths. """ hidden_states = self.embeddings(pixel_values, grid_thw=grid_thw) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, grid_thw=grid_thw, - cu_seqlens=cu_seqlens, attention_mask=attention_mask, - position_ids=position_ids, **kwargs, ) @@ -910,8 +894,6 @@ def forward( self, pixel_values: torch.FloatTensor, grid_thw: torch.LongTensor | None = None, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -920,16 +902,8 @@ def forward( The tensors corresponding to the input images. grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. """ - return self.vision_model( - pixel_values=pixel_values, - grid_thw=grid_thw, - cu_seqlens=cu_seqlens, - position_ids=position_ids, - **kwargs, - ) + return self.vision_model(pixel_values=pixel_values, grid_thw=grid_thw, **kwargs) class PaddleOCRVLModelOutputWithPast(Qwen2VLModelOutputWithPast): diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index be4742830e51..d54c35171be5 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -876,11 +876,6 @@ def forward( input_features=None, feature_lens=None, aftercnn_lens=None, - padded_feature=None, - chunk_lengths=None, - valid_indices=None, - pool_indices=None, - cu_seqlens=None, **kwargs: Unpack[TransformersKwargs], ): r""" @@ -888,25 +883,14 @@ def forward( mel length aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`, *optional*): mel length after cnn - padded_feature (`torch.FloatTensor`, *optional*): - Precomputed padded audio chunks (from `chunk_and_pad_features`). - chunk_lengths (`torch.LongTensor`, *optional*): - Precomputed per-chunk lengths (from `chunk_and_pad_features`). - valid_indices (`torch.LongTensor`, *optional*): - Precomputed flat indices of valid post-CNN positions (from `get_valid_indices`). - pool_indices (`torch.LongTensor`, *optional*): - Precomputed pair indices for stride-2 average pooling (from `get_pool_indices`). 
- cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_audio_cu_seqlens`). """ + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) if padded_feature is None: padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) - - if valid_indices is None: - valid_indices = get_valid_indices(chunk_lengths) - - if pool_indices is None: - pool_indices = get_pool_indices(feature_lens) + valid_indices = kwargs.pop("valid_indices", None) or get_valid_indices(chunk_lengths) + pool_indices = kwargs.pop("pool_indices", None) or get_pool_indices(feature_lens) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_audio_cu_seqlens(chunk_lengths) # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) padded_feature = padded_feature.to(self.conv1.weight.dtype) @@ -922,9 +906,6 @@ def forward( ].unsqueeze(0).to(padded_embed.dtype) hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) - if cu_seqlens is None: - cu_seqlens = get_audio_cu_seqlens(chunk_lengths) - for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, @@ -1285,10 +1266,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - window_index: torch.Tensor | None = None, - cu_window_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -1297,31 +1274,21 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - cu_window_seqlens (`torch.Tensor`, *optional*): - Precomputed window cumulative sequence lengths (from `get_vision_window_index`). - window_index (`torch.Tensor`, *optional*): - Precomputed window reordering index (from `get_vision_window_index`). - position_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + window_index = kwargs.pop("window_index", None) + cu_window_seqlens = kwargs.pop("cu_window_seqlens", None) if window_index is None: window_index, cu_window_seqlens = get_vision_window_index( grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit ) + hidden_states = self.patch_embed(hidden_states) + seq_len, _ = hidden_states.size() reverse_indices = torch.argsort(window_index) hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index be82c77a87cc..54d5ac22cef7 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1353,11 +1353,6 @@ def forward( input_features=None, feature_lens=None, aftercnn_lens=None, - padded_feature=None, - chunk_lengths=None, - valid_indices=None, - pool_indices=None, - cu_seqlens=None, **kwargs: Unpack[TransformersKwargs], ): r""" @@ -1365,25 +1360,14 @@ def forward( mel length aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`, *optional*): mel length after cnn - padded_feature (`torch.FloatTensor`, *optional*): - Precomputed padded audio chunks (from `chunk_and_pad_features`). - chunk_lengths (`torch.LongTensor`, *optional*): - Precomputed per-chunk lengths (from `chunk_and_pad_features`). - valid_indices (`torch.LongTensor`, *optional*): - Precomputed flat indices of valid post-CNN positions (from `get_valid_indices`). - pool_indices (`torch.LongTensor`, *optional*): - Precomputed pair indices for stride-2 average pooling (from `get_pool_indices`). - cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_audio_cu_seqlens`). 
""" + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) if padded_feature is None: padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) - - if valid_indices is None: - valid_indices = get_valid_indices(chunk_lengths) - - if pool_indices is None: - pool_indices = get_pool_indices(feature_lens) + valid_indices = kwargs.pop("valid_indices", None) or get_valid_indices(chunk_lengths) + pool_indices = kwargs.pop("pool_indices", None) or get_pool_indices(feature_lens) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_audio_cu_seqlens(chunk_lengths) # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) padded_feature = padded_feature.to(self.conv1.weight.dtype) @@ -1399,9 +1383,6 @@ def forward( ].unsqueeze(0).to(padded_embed.dtype) hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) - if cu_seqlens is None: - cu_seqlens = get_audio_cu_seqlens(chunk_lengths) - for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, @@ -1610,10 +1591,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - window_index: torch.Tensor | None = None, - cu_window_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -1622,31 +1599,21 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - cu_window_seqlens (`torch.Tensor`, *optional*): - Precomputed window cumulative sequence lengths (from `get_vision_window_index`). - window_index (`torch.Tensor`, *optional*): - Precomputed window reordering index (from `get_vision_window_index`). - position_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + window_index = kwargs.pop("window_index", None) + cu_window_seqlens = kwargs.pop("cu_window_seqlens", None) if window_index is None: window_index, cu_window_seqlens = get_vision_window_index( grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit ) + hidden_states = self.patch_embed(hidden_states) + seq_len, _ = hidden_states.size() reverse_indices = torch.argsort(window_index) hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index c8a76f8982fc..c60338fa655a 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -414,10 +414,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - window_index: torch.Tensor | None = None, - cu_window_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -426,31 +422,21 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - window_index (`torch.Tensor`, *optional*): - Precomputed window reordering index (from `get_vision_window_index`). - cu_window_seqlens (`torch.Tensor`, *optional*): - Precomputed window cumulative sequence lengths (from `get_vision_window_index`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + window_index = kwargs.pop("window_index", None) + cu_window_seqlens = kwargs.pop("cu_window_seqlens", None) if window_index is None: window_index, cu_window_seqlens = get_vision_window_index( grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit ) + hidden_states = self.patch_embed(hidden_states) + seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 9fcbf66b784f..2adc3e4989fb 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -250,10 +250,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - window_index: torch.Tensor | None = None, - cu_window_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ @@ -262,31 +258,21 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - window_index (`torch.Tensor`, *optional*): - Precomputed window reordering index (from `get_vision_window_index`). - cu_window_seqlens (`torch.Tensor`, *optional*): - Precomputed window cumulative sequence lengths (from `get_vision_window_index`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + window_index = kwargs.pop("window_index", None) + cu_window_seqlens = kwargs.pop("cu_window_seqlens", None) if window_index is None: window_index, cu_window_seqlens = get_vision_window_index( grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit ) + hidden_states = self.patch_embed(hidden_states) + seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 897d47c874e8..84845eb8528b 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -739,31 +739,20 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: r""" grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. - cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). """ - hidden_states = self.patch_embed(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rotary_pos_emb(position_ids) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - for blk in self.blocks: hidden_states = blk( hidden_states, diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index baed2fe9b0fc..5dc0779c72d9 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1068,10 +1068,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - bilinear_indices: torch.Tensor | None = None, - bilinear_weights: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: """ @@ -1080,35 +1076,24 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). 
- position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). - bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). - bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. """ - hidden_states = self.patch_embed(hidden_states) - + bilinear_indices = kwargs.pop("bilinear_indices", None) + bilinear_weights = kwargs.pop("bilinear_weights", None) if bilinear_indices is None or bilinear_weights is None: bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + + hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index ab7cd4efd664..0f00e3f7800b 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -434,10 +434,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - bilinear_indices: torch.Tensor | None = None, - bilinear_weights: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: """ @@ -446,35 +442,24 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). - bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). - bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - + bilinear_indices = kwargs.pop("bilinear_indices", None) + bilinear_weights = kwargs.pop("bilinear_weights", None) if bilinear_indices is None or bilinear_weights is None: bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + + hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index e84062bd43c7..5edbfaed1dfc 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1161,10 +1161,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - bilinear_indices: torch.Tensor | None = None, - bilinear_weights: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: """ @@ -1173,35 +1169,24 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). - bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). - bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - + bilinear_indices = kwargs.pop("bilinear_indices", None) + bilinear_weights = kwargs.pop("bilinear_weights", None) if bilinear_indices is None or bilinear_weights is None: bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + + hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 863c92814372..2e16f2f45f91 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -776,32 +776,20 @@ def forward( self, input_features=None, feature_lens=None, - padded_feature=None, - chunk_lengths=None, - valid_indices=None, - cu_seqlens=None, **kwargs: Unpack[TransformersKwargs], ): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length - padded_feature (`torch.FloatTensor`, *optional*): - Precomputed padded audio chunks (from `chunk_and_pad_features`). - chunk_lengths (`torch.LongTensor`, *optional*): - Precomputed per-chunk lengths (from `chunk_and_pad_features`). - valid_indices (`torch.LongTensor`, *optional*): - Precomputed flat indices of valid post-CNN positions (from `get_valid_indices`). - cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_audio_cu_seqlens`). """ + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) if padded_feature is None: padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) - - if valid_indices is None: - valid_indices = get_valid_indices(chunk_lengths) - - if cu_seqlens is None: - cu_seqlens = get_audio_cu_seqlens(chunk_lengths, feature_lens, self.n_window_infer, self.n_window) + valid_indices = kwargs.pop("valid_indices", None) or get_valid_indices(chunk_lengths) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_audio_cu_seqlens( + chunk_lengths, feature_lens, self.n_window_infer, self.n_window + ) # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) padded_feature = padded_feature.unsqueeze(1) @@ -1192,10 +1180,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - bilinear_indices: torch.Tensor | None = None, - bilinear_weights: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ @@ -1204,35 +1188,24 @@ def forward( The final hidden states of the model. 
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). - bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). - bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. """ - hidden_states = self.patch_embed(hidden_states) - + bilinear_indices = kwargs.pop("bilinear_indices", None) + bilinear_weights = kwargs.pop("bilinear_weights", None) if bilinear_indices is None or bilinear_weights is None: bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + + hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 2ffe2dc3fb83..62be586b58ca 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -991,32 +991,20 @@ def forward( self, input_features=None, feature_lens=None, - padded_feature=None, - chunk_lengths=None, - valid_indices=None, - cu_seqlens=None, **kwargs: Unpack[TransformersKwargs], ): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length - padded_feature (`torch.FloatTensor`, *optional*): - Precomputed padded audio chunks (from `chunk_and_pad_features`). - chunk_lengths (`torch.LongTensor`, *optional*): - Precomputed per-chunk lengths (from `chunk_and_pad_features`). - valid_indices (`torch.LongTensor`, *optional*): - Precomputed flat indices of valid post-CNN positions (from `get_valid_indices`). - cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_audio_cu_seqlens`). 
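# Usage sketch of the audio precompute path referenced above, with stand-in shapes and config
# values; the helpers are the ones defined in the qwen3_omni_moe modeling/modular files of
# this series, and the import path below is an assumption.
import torch

from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (  # assumed path
    chunk_and_pad_features,
    get_audio_cu_seqlens,
    get_valid_indices,
)

input_features = torch.randn(128, 3000)   # (mel_bins, total_frames) stand-in
feature_lens = torch.tensor([3000])
n_window, n_window_infer = 50, 400        # stand-ins for the audio config values

padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, n_window)
precomputed = {
    "padded_feature": padded_feature,
    "chunk_lengths": chunk_lengths,
    "valid_indices": get_valid_indices(chunk_lengths),
    "cu_seqlens": get_audio_cu_seqlens(chunk_lengths, feature_lens, n_window_infer, n_window),
}
# audio_hidden = audio_tower(input_features, feature_lens, **precomputed)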
""" + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) if padded_feature is None: padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) - - if valid_indices is None: - valid_indices = get_valid_indices(chunk_lengths) - - if cu_seqlens is None: - cu_seqlens = get_audio_cu_seqlens(chunk_lengths, feature_lens, self.n_window_infer, self.n_window) + valid_indices = kwargs.pop("valid_indices", None) or get_valid_indices(chunk_lengths) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_audio_cu_seqlens( + chunk_lengths, feature_lens, self.n_window_infer, self.n_window + ) # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) padded_feature = padded_feature.unsqueeze(1) diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 517598021632..d70afa6262d2 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -690,10 +690,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - bilinear_indices: torch.Tensor | None = None, - bilinear_weights: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ @@ -702,35 +698,24 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). - bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). - bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - + bilinear_indices = kwargs.pop("bilinear_indices", None) + bilinear_weights = kwargs.pop("bilinear_weights", None) if bilinear_indices is None or bilinear_weights is None: bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + + hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index d056351c11ed..f72648e26b0b 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -473,10 +473,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - bilinear_indices: torch.Tensor | None = None, - bilinear_weights: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ @@ -485,35 +481,24 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). - bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). - bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - + bilinear_indices = kwargs.pop("bilinear_indices", None) + bilinear_weights = kwargs.pop("bilinear_weights", None) if bilinear_indices is None or bilinear_weights is None: bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + + hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index f7dc594a395e..7d099afe2e16 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -675,10 +675,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, - bilinear_indices: torch.Tensor | None = None, - bilinear_weights: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ @@ -687,35 +683,24 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor` of shape `(total_tokens, 2)`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). - bilinear_indices (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner indices for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). - bilinear_weights (`torch.Tensor` of shape `(4, total_thw)`, *optional*): - Bilinear corner weights for position embedding interpolation (from `get_vision_bilinear_indices_and_weights`). Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - + bilinear_indices = kwargs.pop("bilinear_indices", None) + bilinear_weights = kwargs.pop("bilinear_weights", None) if bilinear_indices is None or bilinear_weights is None: bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( grid_thw, self.num_grid_per_side, self.config.spatial_merge_size ) + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + + hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - rotary_pos_emb = self.rotary_pos_emb(position_ids) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index be8bea0dc2ad..2996e354d829 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -414,8 +414,6 @@ def forward( pixel_values: torch.Tensor, grid_thw: torch.Tensor, merge_sizes: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutput: r""" @@ -423,24 +421,15 @@ def forward( The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. merge_sizes (`torch.Tensor` of shape `(num_images_or_videos,)`): The spatial downsampling ratio of each image or video feature. - cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). 
""" + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, merge_sizes) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) hidden_states = self.embeddings(pixel_values.type(self.dtype)) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, merge_sizes) - rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - encoder_outputs: BaseModelOutput = self.encoder( hidden_states, cu_seqlens=cu_seqlens, diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index b90062fbd362..9d970b8c81f2 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -376,8 +376,6 @@ def forward( pixel_values: torch.Tensor, grid_thw: torch.Tensor, merge_sizes: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutput: r""" @@ -385,24 +383,15 @@ def forward( The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. merge_sizes (`torch.Tensor` of shape `(num_images_or_videos,)`): The spatial downsampling ratio of each image or video feature. - cu_seqlens (`torch.IntTensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). """ + position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, merge_sizes) + cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) hidden_states = self.embeddings(pixel_values.type(self.dtype)) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, merge_sizes) - rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) - encoder_outputs: BaseModelOutput = self.encoder( hidden_states, cu_seqlens=cu_seqlens, From 4838e17991bc5a0c74e8f60cfe19eff434fa0668 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 28 Apr 2026 12:06:55 +0200 Subject: [PATCH 51/56] revert some trailing commas --- .../modeling_ernie4_5_vl_moe.py | 5 +---- .../modular_ernie4_5_vl_moe.py | 5 +---- .../models/glm4v/modeling_glm4v.py | 5 +---- .../models/glm4v/modular_glm4v.py | 5 +---- .../models/glm4v_moe/modeling_glm4v_moe.py | 5 +---- .../models/glm_image/modeling_glm_image.py | 5 +---- .../models/glm_image/modular_glm_image.py | 5 +---- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 11 ++-------- .../qwen2_5_omni/modular_qwen2_5_omni.py | 11 ++-------- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 5 +---- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 5 +---- .../models/qwen2_vl/modeling_qwen2_vl.py | 5 +---- .../models/qwen3_5/modeling_qwen3_5.py | 7 +----- .../models/qwen3_5/modular_qwen3_5.py | 22 ++++--------------- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 7 +----- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 12 ++-------- .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 7 +----- .../models/qwen3_vl/modeling_qwen3_vl.py | 5 +---- 
.../models/qwen3_vl/modular_qwen3_vl.py | 5 +---- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 5 +---- 20 files changed, 26 insertions(+), 116 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index fce3ce797927..c999a80a6903 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -911,10 +911,7 @@ def rot_pos_emb(self, grid_thw): @merge_with_config_defaults @capture_outputs def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: r""" grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index 3b2d7721a3d9..d1d705c1eeb0 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -699,10 +699,7 @@ def get_device(self): @merge_with_config_defaults @capture_outputs def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 3c4266995c15..7b1543cefac0 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -745,10 +745,7 @@ def rot_pos_emb(self, grid_thw): @capture_outputs @auto_docstring def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: r""" hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 9dc601aecc94..2f85b3e46d34 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -624,10 +624,7 @@ def rot_pos_emb(self, grid_thw): @capture_outputs @auto_docstring def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: r""" hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index b60b8870945f..c42d1a41ab9f 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -810,10 +810,7 @@ def rot_pos_emb(self, grid_thw): @capture_outputs @auto_docstring def forward( - self, - hidden_states: 
torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: r""" hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 64ef7d636d5d..396db8c5ad72 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -602,10 +602,7 @@ def rot_pos_emb(self, grid_thw): @capture_outputs @auto_docstring def forward( - self, - pixel_values: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: r""" pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`): diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index ba7eb08aa08e..dc99a0fea98a 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -442,10 +442,7 @@ def rot_pos_emb(self, grid_thw): @capture_outputs @auto_docstring def forward( - self, - pixel_values: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: r""" pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`): diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index d54c35171be5..7c283e5b6641 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -872,11 +872,7 @@ def set_input_embeddings(self, value: nn.Module): @capture_outputs(tie_last_hidden_states=False) @auto_docstring def forward( - self, - input_features=None, - feature_lens=None, - aftercnn_lens=None, - **kwargs: Unpack[TransformersKwargs], + self, input_features=None, feature_lens=None, aftercnn_lens=None, **kwargs: Unpack[TransformersKwargs] ): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): @@ -1263,10 +1259,7 @@ def get_window_index(self, grid_thw): @merge_with_config_defaults @capture_outputs def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: """ Args: diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 54d5ac22cef7..021fbb745276 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1349,11 +1349,7 @@ def set_input_embeddings(self, value: nn.Module): @capture_outputs(tie_last_hidden_states=False) @auto_docstring def forward( - self, - input_features=None, - feature_lens=None, - aftercnn_lens=None, - **kwargs: Unpack[TransformersKwargs], + self, input_features=None, feature_lens=None, aftercnn_lens=None, **kwargs: 
Unpack[TransformersKwargs] ): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): @@ -1588,10 +1584,7 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> @merge_with_config_defaults @capture_outputs def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: """ Args: diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index c60338fa655a..0ecd7f65fbca 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -411,10 +411,7 @@ def get_window_index(self, grid_thw): @merge_with_config_defaults @capture_outputs def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: """ Args: diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 2adc3e4989fb..b1f088ff5a18 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -247,10 +247,7 @@ def get_window_index(self, grid_thw): @merge_with_config_defaults @capture_outputs def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: """ Args: diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 84845eb8528b..bbbd3ad4bd8b 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -736,10 +736,7 @@ def rot_pos_emb(self, grid_thw): @capture_outputs @auto_docstring def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> torch.Tensor: r""" grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 5dc0779c72d9..2a46598472b5 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1064,12 +1064,7 @@ def fast_pos_embed_interpolate(self, grid_thw): @merge_with_config_defaults @capture_outputs - def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs, - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 0f00e3f7800b..e6d1595df969 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -430,12 +430,7 @@ def __init__(self, config, *inputs, **kwargs) -> None: 
@merge_with_config_defaults @capture_outputs - def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs, - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): @@ -562,10 +557,7 @@ def forward( class Qwen3_5Model(Qwen3VLModel): _no_split_modules = ["Qwen3_5DecoderLayer", "Qwen3_5VisionBlock"] - def get_video_features( - self, - **super_kwargs, - ) -> tuple | BaseModelOutputWithPooling: + def get_video_features(self, **super_kwargs) -> tuple | BaseModelOutputWithPooling: # Same implementation as for images return super().get_video_features(**super_kwargs) @@ -679,16 +671,10 @@ class Qwen3_5ForSequenceClassification(GenericForSequenceClassification, Qwen3_5 class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration): - def get_video_features( - self, - **super_kwargs, - ) -> tuple | BaseModelOutputWithPooling: + def get_video_features(self, **super_kwargs) -> tuple | BaseModelOutputWithPooling: return super().get_video_features(**super_kwargs) - def get_image_features( - self, - **super_kwargs, - ) -> tuple | BaseModelOutputWithPooling: + def get_image_features(self, **super_kwargs) -> tuple | BaseModelOutputWithPooling: return super().get_image_features(**super_kwargs) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 5edbfaed1dfc..8613d079238a 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1157,12 +1157,7 @@ def fast_pos_embed_interpolate(self, grid_thw): @merge_with_config_defaults @capture_outputs - def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs, - ) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 2e16f2f45f91..640413b2e10f 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -772,12 +772,7 @@ def set_input_embeddings(self, value): @merge_with_config_defaults @capture_outputs(tie_last_hidden_states=False) @auto_docstring - def forward( - self, - input_features=None, - feature_lens=None, - **kwargs: Unpack[TransformersKwargs], - ): + def forward(self, input_features=None, feature_lens=None, **kwargs: Unpack[TransformersKwargs]): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length @@ -1177,10 +1172,7 @@ def fast_pos_embed_interpolate(self, grid_thw): @merge_with_config_defaults @capture_outputs def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ Args: diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 62be586b58ca..89a54c991ae2 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ 
b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -987,12 +987,7 @@ def set_input_embeddings(self, value): @merge_with_config_defaults @capture_outputs(tie_last_hidden_states=False) @auto_docstring - def forward( - self, - input_features=None, - feature_lens=None, - **kwargs: Unpack[TransformersKwargs], - ): + def forward(self, input_features=None, feature_lens=None, **kwargs: Unpack[TransformersKwargs]): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index d70afa6262d2..ec6cdd57d5b1 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -687,10 +687,7 @@ def fast_pos_embed_interpolate(self, grid_thw): @merge_with_config_defaults @capture_outputs def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ Args: diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index f72648e26b0b..626732ea0fcb 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -470,10 +470,7 @@ def fast_pos_embed_interpolate(self, grid_thw): @merge_with_config_defaults @capture_outputs def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ Args: diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 7d099afe2e16..1531ed2a7563 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -672,10 +672,7 @@ def fast_pos_embed_interpolate(self, grid_thw): @merge_with_config_defaults @capture_outputs def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithDeepstackFeatures: """ Args: From 4246a72dbaac243d44ed1751bcb537dcf05d987c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Tue, 28 Apr 2026 13:11:28 +0200 Subject: [PATCH 52/56] fix --- .../modeling_ernie4_5_vl_moe.py | 4 +- .../modular_ernie4_5_vl_moe.py | 4 +- .../models/glm4v/modeling_glm4v.py | 4 +- .../models/glm4v/modular_glm4v.py | 4 +- .../models/glm4v_moe/modeling_glm4v_moe.py | 4 +- .../models/glm_image/modeling_glm_image.py | 4 +- .../models/glm_image/modular_glm_image.py | 4 +- .../paddleocr_vl/modeling_paddleocr_vl.py | 4 +- .../paddleocr_vl/modular_paddleocr_vl.py | 4 +- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 62 ++++++++++++------- .../qwen2_5_omni/modular_qwen2_5_omni.py | 62 ++++++++++++------- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 18 +++--- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 18 +++--- .../models/qwen2_vl/modeling_qwen2_vl.py | 4 +- .../models/qwen3_5/modeling_qwen3_5.py | 16 ++--- .../models/qwen3_5/modular_qwen3_5.py | 16 ++--- 
.../qwen3_5_moe/modeling_qwen3_5_moe.py | 16 ++--- .../qwen3_omni_moe/modeling_qwen3_omni_moe.py | 58 ++++++++++------- .../qwen3_omni_moe/modular_qwen3_omni_moe.py | 42 +++++++++---- .../models/qwen3_vl/modeling_qwen3_vl.py | 16 ++--- .../models/qwen3_vl/modular_qwen3_vl.py | 16 ++--- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 16 ++--- .../video_llama_3/modeling_video_llama_3.py | 4 +- .../video_llama_3/modular_video_llama_3.py | 4 +- src/transformers/vision_utils.py | 46 +++++++++++--- 25 files changed, 275 insertions(+), 175 deletions(-) diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index c999a80a6903..c15bca5033a4 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -917,8 +917,8 @@ def forward( grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. """ - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rotary_pos_emb(position_ids) diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index d1d705c1eeb0..a176dae49f45 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -701,8 +701,8 @@ def get_device(self): def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rotary_pos_emb(position_ids) diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 7b1543cefac0..bda42b8ec99d 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -756,8 +756,8 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
""" - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 2f85b3e46d34..f7f41ceb00b0 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -635,8 +635,8 @@ def forward( Returns: `torch.Tensor`: hidden_states. """ - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index c42d1a41ab9f..405a5de64833 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -821,8 +821,8 @@ def forward( Returns: `torch.Tensor`: hidden_states. """ - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 396db8c5ad72..601f4f97c583 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -613,8 +613,8 @@ def forward( Returns: `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. """ - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(pixel_values) seqlens = cu_seqlens[1:] - cu_seqlens[:-1] diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index dc99a0fea98a..b71be80c71ea 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -453,8 +453,8 @@ def forward( Returns: `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. 
""" - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(pixel_values) seqlens = cu_seqlens[1:] - cu_seqlens[:-1] diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index b5e6f8e2a228..d54259fe2902 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -840,8 +840,8 @@ def forward( """ # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, 1) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, 1, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index a5336520c341..16051aa7305a 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -796,8 +796,8 @@ def forward( """ # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, 1) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, 1, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 7c283e5b6641..a1fb96847d19 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -744,9 +744,9 @@ def forward(self, seqlen: int): def chunk_and_pad_features( - input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int, *, kwargs: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor]: - """Split audio features into fixed-size chunks and pad to uniform length. + """Split audio features into fixed-size chunks and pad to uniform length, or pop precomputed pair from ``kwargs``. Each audio sample is split into chunks of ``n_window * 2`` frames (the last chunk may be shorter), then all chunks are right-padded to the longest chunk. @@ -755,11 +755,17 @@ def chunk_and_pad_features( input_features: ``(feature_dim, total_frames)`` concatenated audio features. feature_lens: ``(batch_size,)`` per-sample frame counts. n_window: half the target chunk size in frames. 
+ kwargs: optional caller kwargs — if it contains both ``"padded_feature"`` and ``"chunk_lengths"`` they are popped and returned. Returns: ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. """ + if kwargs is not None: + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) + if padded_feature is not None and chunk_lengths is not None: + return padded_feature, chunk_lengths chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] @@ -770,46 +776,55 @@ def chunk_and_pad_features( return padded_feature, chunk_lengths -def get_audio_cu_seqlens(chunk_lengths: torch.Tensor) -> torch.Tensor: - """Compute cumulative sequence lengths for audio attention from chunk lengths. +def get_audio_cu_seqlens(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute cumulative sequence lengths for audio attention, or pop ``"cu_seqlens"`` from ``kwargs`` if precomputed. Applies one stride-2 convolution length reduction, then returns cumulative boundaries for flash-attention-style sequence packing. Args: chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + kwargs: optional caller kwargs — if it contains ``"cu_seqlens"`` it is popped and returned. Returns: ``(num_chunks + 1,)`` int32 cumulative sequence boundaries. """ + if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: + return cu_seqlens after_conv1 = (chunk_lengths - 1) // 2 + 1 return F.pad(after_conv1.cumsum(0), (1, 0), value=0).to(torch.int32) -def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: - """Compute flat indices of valid (non-padding) positions after one stride-2 conv. +def get_valid_indices(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after one stride-2 conv, or pop ``"valid_indices"`` from ``kwargs`` if precomputed. Args: chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + kwargs: optional caller kwargs — if it contains ``"valid_indices"`` it is popped and returned. Returns: ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_conv)`` grid. """ + if kwargs is not None and (valid_indices := kwargs.pop("valid_indices", None)) is not None: + return valid_indices after_conv1 = (chunk_lengths - 1) // 2 + 1 max_len = after_conv1.max().item() mask = torch.arange(max_len, device=chunk_lengths.device) < after_conv1.unsqueeze(1) return mask.flatten().nonzero().squeeze(-1) -def get_pool_indices(feature_lens: torch.Tensor) -> torch.Tensor: - """Compute indices for post-encoder stride-2 average pooling. +def get_pool_indices(feature_lens: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute indices for post-encoder stride-2 average pooling, or pop ``"pool_indices"`` from ``kwargs`` if precomputed. Args: feature_lens: ``(batch_size,)`` mel spectrogram lengths. + kwargs: optional caller kwargs — if it contains ``"pool_indices"`` it is popped and returned. Returns: ``(total_pooled,)`` flat index of first element of each stride-2 pair. 
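# Tiny worked example of the stride-2 length arithmetic used by get_audio_cu_seqlens and
# get_valid_indices above, for one full 100-frame chunk and one 37-frame tail chunk:
import torch
import torch.nn.functional as F

chunk_lengths = torch.tensor([100, 37])
after_conv1 = (chunk_lengths - 1) // 2 + 1                                  # tensor([50, 19])
cu_seqlens = F.pad(after_conv1.cumsum(0), (1, 0), value=0).to(torch.int32)  # tensor([0, 50, 69])
max_len = after_conv1.max().item()                                          # 50
valid = torch.arange(max_len) < after_conv1.unsqueeze(1)                    # (2, 50) bool mask
# valid.flatten().nonzero() keeps positions 0..49 of chunk 0 and 50..68 of chunk 1 (69 total)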
""" + if kwargs is not None and (pool_indices := kwargs.pop("pool_indices", None)) is not None: + return pool_indices after_conv1 = (feature_lens - 1) // 2 + 1 num_pooled = (after_conv1 - 2) // 2 + 1 offsets = F.pad(after_conv1[:-1].cumsum(0), (1, 0), value=0) @@ -880,13 +895,12 @@ def forward( aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`, *optional*): mel length after cnn """ - padded_feature = kwargs.pop("padded_feature", None) - chunk_lengths = kwargs.pop("chunk_lengths", None) - if padded_feature is None: - padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) - valid_indices = kwargs.pop("valid_indices", None) or get_valid_indices(chunk_lengths) - pool_indices = kwargs.pop("pool_indices", None) or get_pool_indices(feature_lens) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_audio_cu_seqlens(chunk_lengths) + padded_feature, chunk_lengths = chunk_and_pad_features( + input_features, feature_lens, self.n_window, kwargs=kwargs + ) + valid_indices = get_valid_indices(chunk_lengths, kwargs=kwargs) + pool_indices = get_pool_indices(feature_lens, kwargs=kwargs) + cu_seqlens = get_audio_cu_seqlens(chunk_lengths, kwargs=kwargs) # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) padded_feature = padded_feature.to(self.conv1.weight.dtype) @@ -1271,14 +1285,16 @@ def forward( Returns: `torch.Tensor`: hidden_states. """ - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) - window_index = kwargs.pop("window_index", None) - cu_window_seqlens = kwargs.pop("cu_window_seqlens", None) - if window_index is None: - window_index, cu_window_seqlens = get_vision_window_index( - grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit - ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + kwargs=kwargs, + ) hidden_states = self.patch_embed(hidden_states) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 021fbb745276..6d885813acef 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -69,9 +69,9 @@ def chunk_and_pad_features( - input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int, *, kwargs: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor]: - """Split audio features into fixed-size chunks and pad to uniform length. + """Split audio features into fixed-size chunks and pad to uniform length, or pop precomputed pair from ``kwargs``. Each audio sample is split into chunks of ``n_window * 2`` frames (the last chunk may be shorter), then all chunks are right-padded to the longest chunk. @@ -80,11 +80,17 @@ def chunk_and_pad_features( input_features: ``(feature_dim, total_frames)`` concatenated audio features. feature_lens: ``(batch_size,)`` per-sample frame counts. n_window: half the target chunk size in frames. 
+ kwargs: optional caller kwargs — if it contains both ``"padded_feature"`` and ``"chunk_lengths"`` they are popped and returned. Returns: ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. """ + if kwargs is not None: + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) + if padded_feature is not None and chunk_lengths is not None: + return padded_feature, chunk_lengths chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] @@ -95,46 +101,55 @@ def chunk_and_pad_features( return padded_feature, chunk_lengths -def get_audio_cu_seqlens(chunk_lengths: torch.Tensor) -> torch.Tensor: - """Compute cumulative sequence lengths for audio attention from chunk lengths. +def get_audio_cu_seqlens(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute cumulative sequence lengths for audio attention, or pop ``"cu_seqlens"`` from ``kwargs`` if precomputed. Applies one stride-2 convolution length reduction, then returns cumulative boundaries for flash-attention-style sequence packing. Args: chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + kwargs: optional caller kwargs — if it contains ``"cu_seqlens"`` it is popped and returned. Returns: ``(num_chunks + 1,)`` int32 cumulative sequence boundaries. """ + if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: + return cu_seqlens after_conv1 = (chunk_lengths - 1) // 2 + 1 return F.pad(after_conv1.cumsum(0), (1, 0), value=0).to(torch.int32) -def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: - """Compute flat indices of valid (non-padding) positions after one stride-2 conv. +def get_valid_indices(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after one stride-2 conv, or pop ``"valid_indices"`` from ``kwargs`` if precomputed. Args: chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + kwargs: optional caller kwargs — if it contains ``"valid_indices"`` it is popped and returned. Returns: ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_conv)`` grid. """ + if kwargs is not None and (valid_indices := kwargs.pop("valid_indices", None)) is not None: + return valid_indices after_conv1 = (chunk_lengths - 1) // 2 + 1 max_len = after_conv1.max().item() mask = torch.arange(max_len, device=chunk_lengths.device) < after_conv1.unsqueeze(1) return mask.flatten().nonzero().squeeze(-1) -def get_pool_indices(feature_lens: torch.Tensor) -> torch.Tensor: - """Compute indices for post-encoder stride-2 average pooling. +def get_pool_indices(feature_lens: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute indices for post-encoder stride-2 average pooling, or pop ``"pool_indices"`` from ``kwargs`` if precomputed. Args: feature_lens: ``(batch_size,)`` mel spectrogram lengths. + kwargs: optional caller kwargs — if it contains ``"pool_indices"`` it is popped and returned. Returns: ``(total_pooled,)`` flat index of first element of each stride-2 pair. 
""" + if kwargs is not None and (pool_indices := kwargs.pop("pool_indices", None)) is not None: + return pool_indices after_conv1 = (feature_lens - 1) // 2 + 1 num_pooled = (after_conv1 - 2) // 2 + 1 offsets = F.pad(after_conv1[:-1].cumsum(0), (1, 0), value=0) @@ -1357,13 +1372,12 @@ def forward( aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`, *optional*): mel length after cnn """ - padded_feature = kwargs.pop("padded_feature", None) - chunk_lengths = kwargs.pop("chunk_lengths", None) - if padded_feature is None: - padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) - valid_indices = kwargs.pop("valid_indices", None) or get_valid_indices(chunk_lengths) - pool_indices = kwargs.pop("pool_indices", None) or get_pool_indices(feature_lens) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_audio_cu_seqlens(chunk_lengths) + padded_feature, chunk_lengths = chunk_and_pad_features( + input_features, feature_lens, self.n_window, kwargs=kwargs + ) + valid_indices = get_valid_indices(chunk_lengths, kwargs=kwargs) + pool_indices = get_pool_indices(feature_lens, kwargs=kwargs) + cu_seqlens = get_audio_cu_seqlens(chunk_lengths, kwargs=kwargs) # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) padded_feature = padded_feature.to(self.conv1.weight.dtype) @@ -1596,14 +1610,16 @@ def forward( Returns: `torch.Tensor`: hidden_states. """ - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) - window_index = kwargs.pop("window_index", None) - cu_window_seqlens = kwargs.pop("cu_window_seqlens", None) - if window_index is None: - window_index, cu_window_seqlens = get_vision_window_index( - grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit - ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + kwargs=kwargs, + ) hidden_states = self.patch_embed(hidden_states) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 0ecd7f65fbca..ea8702b69e31 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -423,14 +423,16 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
""" - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) - window_index = kwargs.pop("window_index", None) - cu_window_seqlens = kwargs.pop("cu_window_seqlens", None) - if window_index is None: - window_index, cu_window_seqlens = get_vision_window_index( - grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit - ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + kwargs=kwargs, + ) hidden_states = self.patch_embed(hidden_states) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index b1f088ff5a18..6a3dd95ae216 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -259,14 +259,16 @@ def forward( Returns: `torch.Tensor`: hidden_states. """ - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) - window_index = kwargs.pop("window_index", None) - cu_window_seqlens = kwargs.pop("cu_window_seqlens", None) - if window_index is None: - window_index, cu_window_seqlens = get_vision_window_index( - grid_thw, self.spatial_merge_size, self.window_size, self.patch_size, self.spatial_merge_unit - ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + kwargs=kwargs, + ) hidden_states = self.patch_embed(hidden_states) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index bbbd3ad4bd8b..9c1c69d97248 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -742,8 +742,8 @@ def forward( grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. """ - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rotary_pos_emb(position_ids) diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 2a46598472b5..740a3bb6c938 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -1075,14 +1075,14 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) Returns: `torch.Tensor`: hidden_states. 
""" - bilinear_indices = kwargs.pop("bilinear_indices", None) - bilinear_weights = kwargs.pop("bilinear_weights", None) - if bilinear_indices is None or bilinear_weights is None: - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index e6d1595df969..6226893a1b2c 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -441,14 +441,14 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) Returns: `torch.Tensor`: hidden_states. """ - bilinear_indices = kwargs.pop("bilinear_indices", None) - bilinear_weights = kwargs.pop("bilinear_weights", None) - if bilinear_indices is None or bilinear_weights is None: - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 8613d079238a..d5c8a4d688b7 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -1168,14 +1168,14 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) Returns: `torch.Tensor`: hidden_states. 
""" - bilinear_indices = kwargs.pop("bilinear_indices", None) - bilinear_weights = kwargs.pop("bilinear_weights", None) - if bilinear_indices is None or bilinear_weights is None: - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 640413b2e10f..f16aec3e3f21 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -640,9 +640,9 @@ def forward( def chunk_and_pad_features( - input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int, *, kwargs: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor]: - """Split audio features into fixed-size chunks and pad to uniform length. + """Split audio features into fixed-size chunks and pad to uniform length, or pop precomputed pair from ``kwargs``. Each audio sample is split into chunks of ``n_window * 2`` frames (the last chunk may be shorter), then all chunks are right-padded to the longest chunk. @@ -651,11 +651,17 @@ def chunk_and_pad_features( input_features: ``(feature_dim, total_frames)`` concatenated audio features. feature_lens: ``(batch_size,)`` per-sample frame counts. n_window: half the target chunk size in frames. + kwargs: optional caller kwargs — if it contains both ``"padded_feature"`` and ``"chunk_lengths"`` they are popped and returned. Returns: ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. """ + if kwargs is not None: + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) + if padded_feature is not None and chunk_lengths is not None: + return padded_feature, chunk_lengths chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] @@ -666,15 +672,18 @@ def chunk_and_pad_features( return padded_feature, chunk_lengths -def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: - """Compute flat indices of valid (non-padding) positions after CNN extraction. +def get_valid_indices(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after CNN extraction, or pop ``"valid_indices"`` from ``kwargs`` if precomputed. Args: chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. 
+ kwargs: optional caller kwargs — if it contains ``"valid_indices"`` it is popped and returned. Returns: ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_cnn)`` grid. """ + if kwargs is not None and (valid_indices := kwargs.pop("valid_indices", None)) is not None: + return valid_indices feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) max_len_after_cnn = feature_lens_after_cnn.max().item() mask = torch.arange(max_len_after_cnn, device=chunk_lengths.device) < feature_lens_after_cnn.unsqueeze(1) @@ -682,9 +691,14 @@ def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: def get_audio_cu_seqlens( - chunk_lengths: torch.Tensor, feature_lens: torch.Tensor, n_window_infer: int, n_window: int + chunk_lengths: torch.Tensor, + feature_lens: torch.Tensor, + n_window_infer: int, + n_window: int, + *, + kwargs: dict | None = None, ) -> torch.Tensor: - """Compute cumulative sequence lengths for audio attention windowing. + """Compute cumulative sequence lengths for audio attention windowing, or pop ``"cu_seqlens"`` from ``kwargs`` if precomputed. Splits each sample's post-CNN features into inference windows and returns cumulative boundaries for flash-attention-style sequence packing. @@ -694,10 +708,13 @@ def get_audio_cu_seqlens( feature_lens: ``(batch_size,)`` per-sample frame counts. n_window_infer: inference window size (in raw frames). n_window: half the chunk size (in raw frames). + kwargs: optional caller kwargs — if it contains ``"cu_seqlens"`` it is popped and returned. Returns: ``(num_windows + 1,)`` int32 cumulative sequence boundaries. """ + if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: + return cu_seqlens aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) max_len_after_cnn = feature_lens_after_cnn.max().item() @@ -777,13 +794,12 @@ def forward(self, input_features=None, feature_lens=None, **kwargs: Unpack[Trans feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length """ - padded_feature = kwargs.pop("padded_feature", None) - chunk_lengths = kwargs.pop("chunk_lengths", None) - if padded_feature is None: - padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) - valid_indices = kwargs.pop("valid_indices", None) or get_valid_indices(chunk_lengths) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_audio_cu_seqlens( - chunk_lengths, feature_lens, self.n_window_infer, self.n_window + padded_feature, chunk_lengths = chunk_and_pad_features( + input_features, feature_lens, self.n_window, kwargs=kwargs + ) + valid_indices = get_valid_indices(chunk_lengths, kwargs=kwargs) + cu_seqlens = get_audio_cu_seqlens( + chunk_lengths, feature_lens, self.n_window_infer, self.n_window, kwargs=kwargs ) # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) @@ -1184,14 +1200,14 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
""" - bilinear_indices = kwargs.pop("bilinear_indices", None) - bilinear_weights = kwargs.pop("bilinear_weights", None) - if bilinear_indices is None or bilinear_weights is None: - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 89a54c991ae2..646aef61f4b0 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -115,9 +115,9 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tenso def chunk_and_pad_features( - input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int, *, kwargs: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor]: - """Split audio features into fixed-size chunks and pad to uniform length. + """Split audio features into fixed-size chunks and pad to uniform length, or pop precomputed pair from ``kwargs``. Each audio sample is split into chunks of ``n_window * 2`` frames (the last chunk may be shorter), then all chunks are right-padded to the longest chunk. @@ -126,11 +126,17 @@ def chunk_and_pad_features( input_features: ``(feature_dim, total_frames)`` concatenated audio features. feature_lens: ``(batch_size,)`` per-sample frame counts. n_window: half the target chunk size in frames. + kwargs: optional caller kwargs — if it contains both ``"padded_feature"`` and ``"chunk_lengths"`` they are popped and returned. Returns: ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. """ + if kwargs is not None: + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) + if padded_feature is not None and chunk_lengths is not None: + return padded_feature, chunk_lengths chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] @@ -141,15 +147,18 @@ def chunk_and_pad_features( return padded_feature, chunk_lengths -def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: - """Compute flat indices of valid (non-padding) positions after CNN extraction. +def get_valid_indices(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after CNN extraction, or pop ``"valid_indices"`` from ``kwargs`` if precomputed. Args: chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. 
+ kwargs: optional caller kwargs — if it contains ``"valid_indices"`` it is popped and returned. Returns: ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_cnn)`` grid. """ + if kwargs is not None and (valid_indices := kwargs.pop("valid_indices", None)) is not None: + return valid_indices feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) max_len_after_cnn = feature_lens_after_cnn.max().item() mask = torch.arange(max_len_after_cnn, device=chunk_lengths.device) < feature_lens_after_cnn.unsqueeze(1) @@ -157,9 +166,14 @@ def get_valid_indices(chunk_lengths: torch.Tensor) -> torch.Tensor: def get_audio_cu_seqlens( - chunk_lengths: torch.Tensor, feature_lens: torch.Tensor, n_window_infer: int, n_window: int + chunk_lengths: torch.Tensor, + feature_lens: torch.Tensor, + n_window_infer: int, + n_window: int, + *, + kwargs: dict | None = None, ) -> torch.Tensor: - """Compute cumulative sequence lengths for audio attention windowing. + """Compute cumulative sequence lengths for audio attention windowing, or pop ``"cu_seqlens"`` from ``kwargs`` if precomputed. Splits each sample's post-CNN features into inference windows and returns cumulative boundaries for flash-attention-style sequence packing. @@ -169,10 +183,13 @@ def get_audio_cu_seqlens( feature_lens: ``(batch_size,)`` per-sample frame counts. n_window_infer: inference window size (in raw frames). n_window: half the chunk size (in raw frames). + kwargs: optional caller kwargs — if it contains ``"cu_seqlens"`` it is popped and returned. Returns: ``(num_windows + 1,)`` int32 cumulative sequence boundaries. """ + if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: + return cu_seqlens aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) max_len_after_cnn = feature_lens_after_cnn.max().item() @@ -992,13 +1009,12 @@ def forward(self, input_features=None, feature_lens=None, **kwargs: Unpack[Trans feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length """ - padded_feature = kwargs.pop("padded_feature", None) - chunk_lengths = kwargs.pop("chunk_lengths", None) - if padded_feature is None: - padded_feature, chunk_lengths = chunk_and_pad_features(input_features, feature_lens, self.n_window) - valid_indices = kwargs.pop("valid_indices", None) or get_valid_indices(chunk_lengths) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_audio_cu_seqlens( - chunk_lengths, feature_lens, self.n_window_infer, self.n_window + padded_feature, chunk_lengths = chunk_and_pad_features( + input_features, feature_lens, self.n_window, kwargs=kwargs + ) + valid_indices = get_valid_indices(chunk_lengths, kwargs=kwargs) + cu_seqlens = get_audio_cu_seqlens( + chunk_lengths, feature_lens, self.n_window_infer, self.n_window, kwargs=kwargs ) # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index ec6cdd57d5b1..5a6103724c0e 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -699,14 +699,14 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
""" - bilinear_indices = kwargs.pop("bilinear_indices", None) - bilinear_weights = kwargs.pop("bilinear_weights", None) - if bilinear_indices is None or bilinear_weights is None: - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 626732ea0fcb..32b093caafb0 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -482,14 +482,14 @@ def forward( Returns: `torch.Tensor`: hidden_states. """ - bilinear_indices = kwargs.pop("bilinear_indices", None) - bilinear_weights = kwargs.pop("bilinear_weights", None) - if bilinear_indices is None or bilinear_weights is None: - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 1531ed2a7563..c3fc822ff068 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -684,14 +684,14 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
""" - bilinear_indices = kwargs.pop("bilinear_indices", None) - bilinear_weights = kwargs.pop("bilinear_weights", None) - if bilinear_indices is None or bilinear_weights is None: - bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( - grid_thw, self.num_grid_per_side, self.config.spatial_merge_size - ) - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, self.spatial_merge_size) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(hidden_states) pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 2996e354d829..18474c2739aa 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -422,8 +422,8 @@ def forward( merge_sizes (`torch.Tensor` of shape `(num_images_or_videos,)`): The spatial downsampling ratio of each image or video feature. """ - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, merge_sizes) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, merge_sizes, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.embeddings(pixel_values.type(self.dtype)) rotary_pos_emb = self.rotary_pos_emb(position_ids) diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 9d970b8c81f2..ed835ca5ed19 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -384,8 +384,8 @@ def forward( merge_sizes (`torch.Tensor` of shape `(num_images_or_videos,)`): The spatial downsampling ratio of each image or video feature. """ - position_ids = kwargs.pop("position_ids", None) or get_vision_position_ids(grid_thw, merge_sizes) - cu_seqlens = kwargs.pop("cu_seqlens", None) or get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, merge_sizes, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.embeddings(pixel_values.type(self.dtype)) rotary_pos_emb = self.rotary_pos_emb(position_ids) diff --git a/src/transformers/vision_utils.py b/src/transformers/vision_utils.py index 18d9c6b03da4..528dab1b81ef 100644 --- a/src/transformers/vision_utils.py +++ b/src/transformers/vision_utils.py @@ -18,6 +18,12 @@ ``grid_thw`` + config scalars. They are used by vision encoders and can be precomputed before `torch.compile` / ``torch.export`` tracing since they use untraceable ops (``repeat_interleave``, ``.tolist()``, ``nonzero()``, loops). + +Each ``get_*`` accepts an optional ``kwargs`` dict; if it contains the +precomputed tensor under the natural key (``"cu_seqlens"``, ``"position_ids"``, +…), the function pops and returns it instead of computing. 
Vision encoders +write ``x = get_vision_x(..., kwargs=kwargs)`` and the matching key is +removed from the caller's kwargs as a side-effect of the pop. """ from __future__ import annotations @@ -26,32 +32,40 @@ import torch.nn.functional as F -def get_vision_cu_seqlens(grid_thw: torch.Tensor) -> torch.Tensor: - """Compute cumulative sequence lengths from vision grid info. +def get_vision_cu_seqlens(grid_thw: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Get cumulative sequence lengths from vision grid info, or pop from ``kwargs`` if precomputed. Args: grid_thw: ``(num_images_or_videos, 3)`` — temporal, height, width per entry. + kwargs: optional caller kwargs — if it contains ``"cu_seqlens"`` it is popped and returned. Returns: ``cu_seqlens``: ``(total_patches + 1,)`` int32 cumulative sequence boundaries. """ + if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: + return cu_seqlens cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32 ) return F.pad(cu_seqlens, (1, 0), value=0) -def get_vision_position_ids(grid_thw: torch.Tensor, spatial_merge_size: int | torch.Tensor) -> torch.Tensor: - """Compute (row, col) position IDs for vision rotary embeddings. +def get_vision_position_ids( + grid_thw: torch.Tensor, spatial_merge_size: int | torch.Tensor, *, kwargs: dict | None = None +) -> torch.Tensor: + """Get (row, col) position IDs for vision rotary embeddings, or pop from ``kwargs`` if precomputed. Args: grid_thw: ``(num_images_or_videos, 3)`` spatial_merge_size: merge block size — either a single ``int`` (same for all images) or a ``(num_images_or_videos,)`` tensor (per-image). + kwargs: optional caller kwargs — if it contains ``"position_ids"`` it is popped and returned. Returns: ``position_ids``: ``(total_tokens, 2)`` long — (row, col) position per token. """ + if kwargs is not None and (position_ids := kwargs.pop("position_ids", None)) is not None: + return position_ids device = grid_thw.device if isinstance(spatial_merge_size, int): spatial_merge_size = torch.tensor([spatial_merge_size], device=device).expand(len(grid_thw)) @@ -75,8 +89,10 @@ def get_vision_window_index( window_size: int, patch_size: int, spatial_merge_unit: int, + *, + kwargs: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute window attention indices for vision encoders with windowed attention. + """Get window attention indices, or pop ``"window_index"``/``"cu_window_seqlens"`` from ``kwargs`` if both precomputed. Args: grid_thw: ``(num_images_or_videos, 3)`` @@ -84,11 +100,17 @@ def get_vision_window_index( window_size: window size from vision config. patch_size: patch size from vision config. spatial_merge_unit: ``spatial_merge_size ** 2``. + kwargs: optional caller kwargs — if it contains both ``"window_index"`` and ``"cu_window_seqlens"`` they are popped and returned. Returns: ``window_index``: ``(total_tokens,)`` long — reorder indices for windowed attention. ``cu_window_seqlens``: ``(num_windows + 1,)`` int32 — cumulative window boundaries. 
""" + if kwargs is not None: + window_index = kwargs.pop("window_index", None) + cu_window_seqlens = kwargs.pop("cu_window_seqlens", None) + if window_index is not None and cu_window_seqlens is not None: + return window_index, cu_window_seqlens window_index: list = [] cu_window_seqlens: list = [0] window_index_id = 0 @@ -125,19 +147,29 @@ def get_vision_window_index( def get_vision_bilinear_indices_and_weights( - grid_thw: torch.Tensor, num_grid_per_side: int, spatial_merge_size: int + grid_thw: torch.Tensor, + num_grid_per_side: int, + spatial_merge_size: int, + *, + kwargs: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute bilinear interpolation indices and weights for position embeddings. + """Get bilinear interpolation indices/weights, or pop ``"bilinear_indices"``/``"bilinear_weights"`` from ``kwargs`` if both precomputed. Args: grid_thw: ``(num_images_or_videos, 3)`` num_grid_per_side: ``int(num_position_embeddings ** 0.5)`` from vision config. spatial_merge_size: merge block size from vision config. + kwargs: optional caller kwargs — if it contains both ``"bilinear_indices"`` and ``"bilinear_weights"`` they are popped and returned. Returns: ``bilinear_indices``: ``(4, total_thw)`` long — bilinear corner indices into pos_embed table. ``bilinear_weights``: ``(4, total_thw)`` float — interpolation weights. """ + if kwargs is not None: + bilinear_indices = kwargs.pop("bilinear_indices", None) + bilinear_weights = kwargs.pop("bilinear_weights", None) + if bilinear_indices is not None and bilinear_weights is not None: + return bilinear_indices, bilinear_weights N = num_grid_per_side m = spatial_merge_size device = grid_thw.device From 62d901c6e84059c3d52100702b34a179572834e7 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 29 Apr 2026 09:34:44 +0200 Subject: [PATCH 53/56] fixes --- .../models/video_llama_3/configuration_video_llama_3.py | 2 +- .../models/video_llama_3/modular_video_llama_3.py | 2 +- .../models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py | 2 +- tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/video_llama_3/configuration_video_llama_3.py b/src/transformers/models/video_llama_3/configuration_video_llama_3.py index d25f1266284c..01f10067be1c 100644 --- a/src/transformers/models/video_llama_3/configuration_video_llama_3.py +++ b/src/transformers/models/video_llama_3/configuration_video_llama_3.py @@ -69,7 +69,7 @@ class VideoLlama3Config(PreTrainedConfig): vision_config: dict | PreTrainedConfig | None = None image_token_id: int = 151655 video_token_id: int = 151656 - tie_word_embeddings: bool = False + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.vision_config, dict): diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index ed835ca5ed19..94aac5dcdf72 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -106,7 +106,7 @@ class VideoLlama3Config(PreTrainedConfig): vision_config: dict | PreTrainedConfig | None = None image_token_id: int = 151655 video_token_id: int = 151656 - tie_word_embeddings: bool = False + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.vision_config, dict): diff --git a/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py 
b/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py index 7c3f5064b1ff..fd9bbf0021a6 100644 --- a/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py +++ b/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py @@ -859,7 +859,7 @@ def test_small_model_integration_test_multiturn(self): **inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False, thinker_max_new_tokens=20 ) - EXPECTED_DECODED_TEXT = "user\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.\nuser\nHow about this one?\nassistant\nThe sound is a person coughing." + EXPECTED_DECODED_TEXT = "user\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.\nuser\nHow about this one?\nassistant\nThis is the sound of a person coughing." self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), diff --git a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py index 5946c3a607c4..eed96ebdfb7b 100644 --- a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py +++ b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py @@ -420,7 +420,7 @@ def test_small_model_integration_test(self): inputs = inputs.to(torch_device) output = model.generate(**inputs, max_new_tokens=30, do_sample=False) - EXPECTED_DECODED_TEXT = "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a small wild cat native to the grasslands and steppes" + EXPECTED_DECODED_TEXT = "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a small wild cat native to the steppes and montane" self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT, @@ -498,7 +498,7 @@ def test_small_model_integration_test_expand(self): EXPECTED_DECODED_TEXT = [ "user\nWhat kind of dog is this?\nassistant\nThe animal in the image is not a dog. It is a **Pallas's cat** (*Otocolobus manul*), also known", - "user\nWhat kind of dog is this?\nassistant\nThe animal in the image is not a dog. It is a **Pallas's cat** (also known as the manul), a wild f" + "user\nWhat kind of dog is this?\nassistant\nThe animal in the image is not a dog. It is a **Pallas's cat** (also known as the **manul**), a" ] # fmt: skip self.assertEqual( self.processor.batch_decode(output, skip_special_tokens=True), @@ -548,7 +548,7 @@ def test_small_model_integration_test_batch_wo_image(self): output = model.generate(**inputs, max_new_tokens=30, do_sample=False) EXPECTED_DECODED_TEXT = [ - "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and steppes", + "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the steppes and montane", "user\nWho are you?\nassistant\nI am Qwen, a large-scale language model developed by Alibaba Cloud's Tongyi Lab. I can assist you with answering questions, creating text such" ] # fmt: skip self.assertEqual( @@ -575,7 +575,7 @@ def test_small_model_integration_test_batch_different_resolutions(self): output = model.generate(**inputs, max_new_tokens=30, do_sample=False) EXPECTED_DECODED_TEXT = [ - "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. 
It's a wild cat species native to the grasslands and steppes", + "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and montane regions", "user\nWhat kind of dog is this?\nassistant\nBased on the image provided, the animals are not dogs. They are two cats.\n\nHere is a description of the animals in the image:\n\n- " ] # fmt: skip self.assertEqual( From 1481de69f5400e94cdb945cc5d499a5ea4a62a83 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 29 Apr 2026 12:01:54 +0200 Subject: [PATCH 54/56] video llama fixes --- tests/models/video_llama_3/test_modeling_video_llama_3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/video_llama_3/test_modeling_video_llama_3.py b/tests/models/video_llama_3/test_modeling_video_llama_3.py index 9ade7d43ed2d..8884a47c9b0f 100644 --- a/tests/models/video_llama_3/test_modeling_video_llama_3.py +++ b/tests/models/video_llama_3/test_modeling_video_llama_3.py @@ -883,7 +883,7 @@ def test_small_model_integration_test_batch_wo_image(self): EXPECTED_DECODED_TEXT = Expectations( { ("cuda", None): [ - "user\n\nDescribe the image.\nassistant\nThe image captures a vibrant nighttime scene on a bustling city street. A woman in a striking red dress", + "user\n\nDescribe the image.\nassistant\nThe image captures a vibrant night scene in a bustling Japanese city. A woman in a striking red dress", "user\nWhat is relativity?\nassistant\nRelativity is a scientific theory that describes the relationship between space and time. It was first proposed by", ], ("xpu", None): [ @@ -915,7 +915,7 @@ def test_small_model_integration_test_batch_different_resolutions(self): DECODED_TEXT = self.processor.batch_decode(output, skip_special_tokens=True) EXPECTED_DECODED_TEXT = [ - "user\n\nDescribe the image.\nassistant\nThe image captures a vibrant nighttime scene on a bustling city street. A woman in a striking red dress", + "user\n\nDescribe the image.\nassistant\nThe image captures a vibrant night scene in a bustling Japanese city. A woman in a striking red dress", "user\n\nDescribe the image.\nassistant\nThe image depicts a striking urban scene at night. A person is standing in the center of a wet", ] # fmt: skip From 3538f47d7af560eb666a705661872358b821c75d Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 29 Apr 2026 14:28:47 +0200 Subject: [PATCH 55/56] fix qwen3 vl --- tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py index eed96ebdfb7b..e3dda9cd7b17 100644 --- a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py +++ b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py @@ -576,7 +576,7 @@ def test_small_model_integration_test_batch_different_resolutions(self): EXPECTED_DECODED_TEXT = [ "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and montane regions", - "user\nWhat kind of dog is this?\nassistant\nBased on the image provided, the animals are not dogs. They are two cats.\n\nHere is a description of the animals in the image:\n\n- " + "user\nWhat kind of dog is this?\nassistant\nBased on the image provided, there is no dog present. 
The animals in the picture are two cats.\n\nHere is a description of the scene:\n-" ] # fmt: skip self.assertEqual( self.processor.batch_decode(output, skip_special_tokens=True), From e44ec7b93e3794abe030fdbf8696a940e14b1b41 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 29 Apr 2026 14:50:09 +0200 Subject: [PATCH 56/56] forgot glm ocr --- .../models/glm_ocr/modeling_glm_ocr.py | 16 +++------------- .../models/glm_ocr/modular_glm_ocr.py | 16 +++------------- 2 files changed, 6 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index c68954ba59d7..ca633364b07d 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -600,8 +600,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: r""" @@ -609,22 +607,14 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. """ - hidden_states = self.patch_embed(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = self.patch_embed(hidden_states) rotary_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) diff --git a/src/transformers/models/glm_ocr/modular_glm_ocr.py b/src/transformers/models/glm_ocr/modular_glm_ocr.py index 63812ab0d255..782808bb69a9 100644 --- a/src/transformers/models/glm_ocr/modular_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modular_glm_ocr.py @@ -251,8 +251,6 @@ def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - position_ids: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: r""" @@ -260,22 +258,14 @@ def forward( The final hidden states of the model. grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): The temporal, height and width of feature shape of each image in LLM. - cu_seqlens (`torch.Tensor`, *optional*): - Precomputed cumulative sequence lengths (from `get_vision_cu_seqlens`). - position_ids (`torch.Tensor`, *optional*): - Precomputed (row, col) position IDs (from `get_vision_position_ids`). Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - - if position_ids is None: - position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) - - if cu_seqlens is None: - cu_seqlens = get_vision_cu_seqlens(grid_thw) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = self.patch_embed(hidden_states) rotary_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin())