diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index f3d7bc590f5d..c15bca5033a4 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -19,6 +19,7 @@ # limitations under the License. import itertools +import warnings from collections.abc import Callable from typing import Any, Optional @@ -39,8 +40,14 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import OutputRecorder, capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_ernie4_5_vl_moe import Ernie4_5_VLMoeConfig, Ernie4_5_VLMoeTextConfig, Ernie4_5_VLMoeVisionConfig @@ -855,10 +862,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) @auto_docstring @@ -894,32 +899,13 @@ def __init__(self, config) -> None: self.post_init() def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb @merge_with_config_defaults @@ -931,21 +917,14 @@ def forward( grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. 
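# A minimal reference for the per-patch (h, w) ids that the removed `rot_pos_emb`
# loop produced, and that the new shared `get_vision_position_ids` helper (not
# shown in this diff) is assumed to return: for every image, h/w indices are
# emitted in spatial-merge-block order and repeated over the temporal dimension.
import torch

def reference_vision_position_ids(grid_thw: torch.Tensor, merge_size: int) -> torch.Tensor:
    pos_ids = []
    for t, h, w in grid_thw.tolist():
        hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
        hpos = hpos.reshape(h // merge_size, merge_size, w // merge_size, merge_size)
        hpos = hpos.permute(0, 2, 1, 3).flatten()
        wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
        wpos = wpos.reshape(h // merge_size, merge_size, w // merge_size, merge_size)
        wpos = wpos.permute(0, 2, 1, 3).flatten()
        pos_ids.append(torch.stack([hpos, wpos], dim=-1).repeat(t, 1))
    return torch.cat(pos_ids, dim=0)  # shape (sum(t*h*w), 2)

# e.g. one 4x4 image with merge_size=2 yields 16 (h, w) pairs grouped per 2x2 block
print(reference_vision_position_ids(torch.tensor([[1, 4, 4]]), merge_size=2))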
""" + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for block in self.blocks: hidden_states = block( hidden_states, @@ -1263,6 +1242,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1288,6 +1268,7 @@ def get_video_features( video_outputs.pooler_output = video_embeds return video_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1434,7 +1415,9 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1442,7 +1425,9 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds @@ -1598,9 +1583,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) @auto_docstring def get_image_features( @@ -1615,7 +1598,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @auto_docstring @can_return_tuple diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index ad47bc0508a3..a176dae49f45 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -20,7 +20,6 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from huggingface_hub.dataclasses import strict from torchvision.transforms.v2 import functional as tvF @@ -51,8 +50,9 @@ can_return_tuple, logging, ) -from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..ernie4_5_moe.configuration_ernie4_5_moe import Ernie4_5_MoeConfig from ..ernie4_5_moe.modeling_ernie4_5_moe import ( Ernie4_5_MoeAttention, @@ -701,21 +701,14 @@ def get_device(self): def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> tuple | BaseModelOutputWithPooling: + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for block in self.blocks: hidden_states = block( hidden_states, @@ -962,6 +955,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -981,6 +975,7 @@ def get_video_features( video_outputs.pooler_output = video_embeds return video_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1031,7 +1026,9 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1039,7 +1036,9 @@ def forward( inputs_embeds = 
inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds diff --git a/src/transformers/models/glm46v/modeling_glm46v.py b/src/transformers/models/glm46v/modeling_glm46v.py index 81207e4c8608..90aa1bd8a568 100644 --- a/src/transformers/models/glm46v/modeling_glm46v.py +++ b/src/transformers/models/glm46v/modeling_glm46v.py @@ -37,6 +37,7 @@ can_return_tuple, torch_compilable_check, ) +from ...utils.generic import handle_extra_kwargs from ..auto import AutoModel from .configuration_glm46v import Glm46VConfig @@ -254,6 +255,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -270,12 +272,12 @@ def get_video_features( """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames - temp_frames_hw = [] - video_grid_thw_list = video_grid_thw.tolist() - for t, h, w in video_grid_thw_list: - repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) - temp_frames_hw.append(repeated_row) - flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + t = video_grid_thw[:, 0] + hw = video_grid_thw[:, 1:] + # repeat each (h,w) row `t` times + flattened_hw = torch.repeat_interleave(hw, t, dim=0) + prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) + flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs ) @@ -285,6 +287,7 @@ def get_video_features( return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -300,7 +303,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
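# Quick check that the vectorized video-grid flattening above matches the old
# per-row loop: each [t, h, w] row becomes `t` rows of [1, h, w]. The grid
# values below are toy numbers.
import torch

video_grid_thw = torch.tensor([[3, 4, 6], [2, 8, 8]])

# old loop
loop_rows = torch.cat(
    [torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) for t, h, w in video_grid_thw.tolist()]
)

# new vectorized version
hw = video_grid_thw[:, 1:]
flattened_hw = torch.repeat_interleave(hw, video_grid_thw[:, 0], dim=0)
vectorized_rows = torch.cat([video_grid_thw.new_ones(flattened_hw.shape[0], 1), flattened_hw], dim=1)

assert torch.equal(loop_rows, vectorized_rows)  # 5 rows of [1, h, w]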
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -430,13 +433,17 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) @@ -514,6 +521,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, @@ -527,10 +535,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, @@ -544,7 +551,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 6121dc8d3fe8..bda42b8ec99d 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -40,8 +41,14 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig @@ -110,10 +117,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Glm4vVisionPatchMerger(nn.Module): @@ -180,12 +185,11 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc ) # Calculate target dimensions for each patch - target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) - target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) + num_tokens = embeddings.shape[0] + token_positions = torch.arange(num_tokens, device=embeddings.device) + seg_ids = (token_positions.unsqueeze(0) >= lengths.cumsum(0).unsqueeze(1)).sum(0) + target_h = image_shapes[seg_ids, 1].to(dtype=torch.float32) + target_w = image_shapes[seg_ids, 2].to(dtype=torch.float32) # Normalize coordinates to [-1, 1] range for grid_sample norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 @@ -728,33 +732,14 @@ def __init__(self, config) -> None: self.post_init() def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb, pos_ids + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids @merge_with_config_defaults @capture_outputs @@ -771,28 +756,22 @@ def forward( Returns: `torch.Tensor`: hidden_states. """ + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + rotary_emb = self.rotary_pos_emb(position_ids) + emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, seqlens, grid_thw, - image_type_ids[:, 0].to(hidden_states.device), - image_type_ids[:, 1].to(hidden_states.device), + position_ids[:, 0].to(hidden_states.device), + position_ids[:, 1].to(hidden_states.device), ) for blk in self.blocks: @@ -1097,6 +1076,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1113,12 +1093,12 @@ def get_video_features( """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames - temp_frames_hw = [] - video_grid_thw_list = video_grid_thw.tolist() - for t, h, w in video_grid_thw_list: - repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) - temp_frames_hw.append(repeated_row) - flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + t = video_grid_thw[:, 0] + hw = video_grid_thw[:, 1:] + # repeat each (h,w) row `t` times + flattened_hw = torch.repeat_interleave(hw, t, dim=0) + prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) + flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs ) @@ -1128,6 +1108,7 @@ def get_video_features( return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1143,7 +1124,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
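# Sanity check for the new rotary forward used above: computing frequencies
# directly from the (h, w) position ids is equivalent to building the old
# arange-based table and indexing it with those ids. Dimensions are toy values.
import torch

dim, theta = 8, 10000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
position_ids = torch.tensor([[0, 0], [0, 1], [1, 0], [3, 2]])  # (num_patches, 2)

# old: precompute a table for every possible index, then gather and flatten
table = torch.outer(torch.arange(4, dtype=torch.float), inv_freq)
old = table[position_ids].flatten(1)

# new: multiply the ids by inv_freq directly
new = (position_ids.unsqueeze(-1) * inv_freq).flatten(1)

assert torch.allclose(old, new)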
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1273,13 +1254,17 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) @@ -1357,6 +1342,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, @@ -1370,10 +1356,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, @@ -1387,7 +1372,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index d4a34a1952ad..f7f41ceb00b0 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings from collections.abc import Callable import numpy as np @@ -41,9 +42,10 @@ logging, torch_compilable_check, ) -from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, Glm4RotaryEmbedding, eager_attention_forward from ..qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VisionPatchEmbed, @@ -314,12 +316,11 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc ) # Calculate target dimensions for each patch - target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) - target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) + num_tokens = embeddings.shape[0] + token_positions = torch.arange(num_tokens, device=embeddings.device) + seg_ids = (token_positions.unsqueeze(0) >= lengths.cumsum(0).unsqueeze(1)).sum(0) + target_h = image_shapes[seg_ids, 1].to(dtype=torch.float32) + target_w = image_shapes[seg_ids, 2].to(dtype=torch.float32) # Normalize coordinates to [-1, 1] range for grid_sample norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 @@ -610,33 +611,14 @@ def __init__(self, config) -> None: self.post_init() def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb, pos_ids + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids @merge_with_config_defaults @capture_outputs @@ -653,28 +635,22 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
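# The vectorized token-to-image mapping introduced above, on a toy example:
# comparing each token position against the cumulative lengths counts how many
# images end at or before that token, which is exactly its image index.
import torch

lengths = torch.tensor([3, 2])  # image 0 owns 3 tokens, image 1 owns 2
token_positions = torch.arange(lengths.sum().item())
seg_ids = (token_positions.unsqueeze(0) >= lengths.cumsum(0).unsqueeze(1)).sum(0)
print(seg_ids)  # tensor([0, 0, 0, 1, 1])

# equivalent to the removed torch.cat([image_shapes[i, 1].repeat(lengths[i]) ...])
# pattern, but without a Python loop over images
image_shapes = torch.tensor([[1, 32, 48], [1, 16, 24]])
target_h = image_shapes[seg_ids, 1]  # tensor([32, 32, 32, 16, 16])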
""" + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + rotary_emb = self.rotary_pos_emb(position_ids) + emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, seqlens, grid_thw, - image_type_ids[:, 0].to(hidden_states.device), - image_type_ids[:, 1].to(hidden_states.device), + position_ids[:, 0].to(hidden_states.device), + position_ids[:, 1].to(hidden_states.device), ) for blk in self.blocks: @@ -804,6 +780,7 @@ def __init__(self, config): super().__init__(config) self.visual = Glm4vVisionModel._from_config(config.vision_config) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -820,12 +797,12 @@ def get_video_features( """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames - temp_frames_hw = [] - video_grid_thw_list = video_grid_thw.tolist() - for t, h, w in video_grid_thw_list: - repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) - temp_frames_hw.append(repeated_row) - flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + t = video_grid_thw[:, 0] + hw = video_grid_thw[:, 1:] + # repeat each (h,w) row `t` times + flattened_hw = torch.repeat_interleave(hw, t, dim=0) + prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) + flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs ) @@ -946,13 +923,17 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, 
video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 3bf3dc157d3f..405a5de64833 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -40,8 +41,15 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, torch_compilable_check -from ...utils.generic import can_return_tuple, is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + can_return_tuple, + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig @@ -469,10 +477,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) @use_kernel_forward_from_hub("RMSNorm") @@ -594,12 +600,11 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc ) # Calculate target dimensions for each patch - target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) - target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) + num_tokens = embeddings.shape[0] + token_positions = torch.arange(num_tokens, device=embeddings.device) + seg_ids = (token_positions.unsqueeze(0) >= lengths.cumsum(0).unsqueeze(1)).sum(0) + target_h = image_shapes[seg_ids, 1].to(dtype=torch.float32) + target_w = image_shapes[seg_ids, 2].to(dtype=torch.float32) # Normalize coordinates to [-1, 1] range for grid_sample norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 @@ -792,33 +797,14 @@ def __init__(self, config) -> None: self.post_init() def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, 
wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb, pos_ids + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids @merge_with_config_defaults @capture_outputs @@ -835,28 +821,22 @@ def forward( Returns: `torch.Tensor`: hidden_states. """ + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = self.patch_embed(hidden_states) hidden_states = self.post_conv_layernorm(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + rotary_emb = self.rotary_pos_emb(position_ids) + emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, seqlens, grid_thw, - image_type_ids[:, 0].to(hidden_states.device), - image_type_ids[:, 1].to(hidden_states.device), + position_ids[:, 0].to(hidden_states.device), + position_ids[:, 1].to(hidden_states.device), ) for blk in self.blocks: @@ -1266,6 +1246,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1282,12 +1263,12 @@ def get_video_features( """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames - temp_frames_hw = [] - video_grid_thw_list = video_grid_thw.tolist() - for t, h, w in video_grid_thw_list: - repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) - temp_frames_hw.append(repeated_row) - flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + t = video_grid_thw[:, 0] + hw = video_grid_thw[:, 1:] + # repeat each (h,w) row `t` times + flattened_hw = torch.repeat_interleave(hw, t, dim=0) + prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) + flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs ) @@ -1297,6 +1278,7 @@ def get_video_features( return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1312,7 +1294,7 @@ def 
get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1442,13 +1424,17 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) @@ -1581,6 +1567,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, @@ -1594,10 +1581,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, @@ -1611,7 +1597,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @auto_docstring @can_return_tuple diff --git a/src/transformers/models/glm_image/modeling_glm_image.py b/src/transformers/models/glm_image/modeling_glm_image.py index 012da8513453..601f4f97c583 100644 --- a/src/transformers/models/glm_image/modeling_glm_image.py +++ b/src/transformers/models/glm_image/modeling_glm_image.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -38,8 +39,9 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm_image import GlmImageConfig, GlmImageTextConfig, GlmImageVisionConfig, GlmImageVQVAEConfig @@ -236,12 +238,11 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torc ) # Calculate target dimensions for each patch - target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) - target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to( - device=device, dtype=torch.float32 - ) + num_tokens = embeddings.shape[0] + token_positions = torch.arange(num_tokens, device=embeddings.device) + seg_ids = (token_positions.unsqueeze(0) >= lengths.cumsum(0).unsqueeze(1)).sum(0) + target_h = image_shapes[seg_ids, 1].to(dtype=torch.float32) + target_w = image_shapes[seg_ids, 2].to(dtype=torch.float32) # Normalize coordinates to [-1, 1] range for grid_sample norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 @@ -590,30 +591,12 @@ def __init__(self, config: GlmImageVisionConfig) -> None: self.post_init() def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - return pos_ids + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + return get_vision_position_ids(grid_thw, self.spatial_merge_size) @merge_with_config_defaults @capture_outputs @@ -630,22 +613,17 @@ def forward( Returns: `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. 
""" + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(pixel_values) - image_type_ids = self.rot_pos_emb(grid_thw) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, seqlens, grid_thw, - image_type_ids[:, 0].to(hidden_states.device), - image_type_ids[:, 1].to(hidden_states.device), + position_ids[:, 0].to(hidden_states.device), + position_ids[:, 1].to(hidden_states.device), ) # Transformer blocks (no position_embeddings needed, already added above) @@ -1184,6 +1162,7 @@ def get_rope_index( return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1341,7 +1320,7 @@ def forward( # Fallback for batch_size=1: all but last grid are source images source_grids = image_grid_thw[:-1] - image_features = self.get_image_features(pixel_values, source_grids, return_dict=True) + image_features = self.get_image_features(pixel_values, source_grids, return_dict=True, **kwargs) image_embeds = torch.cat(image_features.pooler_output, dim=0) image_ids = self.get_image_tokens(image_embeds, source_grids) image_ids = image_ids.view(-1).to(input_ids.device) diff --git a/src/transformers/models/glm_image/modular_glm_image.py b/src/transformers/models/glm_image/modular_glm_image.py index e72aede3da66..b71be80c71ea 100644 --- a/src/transformers/models/glm_image/modular_glm_image.py +++ b/src/transformers/models/glm_image/modular_glm_image.py @@ -13,6 +13,7 @@ # limitations under the License. 
import math +import warnings from collections.abc import Callable from typing import Any @@ -31,9 +32,10 @@ from ...processing_utils import ImagesKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, merge_with_config_defaults from ...utils.import_utils import requires from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..chameleon.modeling_chameleon import ChameleonVQVAE, ChameleonVQVAEModelOutput, ChameleonVQVAEVectorQuantizer from ..glm4v.configuration_glm4v import Glm4vTextConfig, Glm4vVisionConfig from ..glm4v.modeling_glm4v import ( @@ -429,30 +431,12 @@ def __init__(self, config: GlmImageVisionConfig): del self.post_layernorm def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - return pos_ids + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + return get_vision_position_ids(grid_thw, self.spatial_merge_size) @merge_with_config_defaults @capture_outputs @@ -469,22 +453,17 @@ def forward( Returns: `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states. 
""" + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.patch_embed(pixel_values) - image_type_ids = self.rot_pos_emb(grid_thw) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = self.embeddings( hidden_states, seqlens, grid_thw, - image_type_ids[:, 0].to(hidden_states.device), - image_type_ids[:, 1].to(hidden_states.device), + position_ids[:, 0].to(hidden_states.device), + position_ids[:, 1].to(hidden_states.device), ) # Transformer blocks (no position_embeddings needed, already added above) @@ -720,6 +699,7 @@ def get_image_tokens( def get_video_features(self): raise AttributeError("Not needed for GlmImage") + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -875,7 +855,7 @@ def forward( # Fallback for batch_size=1: all but last grid are source images source_grids = image_grid_thw[:-1] - image_features = self.get_image_features(pixel_values, source_grids, return_dict=True) + image_features = self.get_image_features(pixel_values, source_grids, return_dict=True, **kwargs) image_embeds = torch.cat(image_features.pooler_output, dim=0) image_ids = self.get_image_tokens(image_embeds, source_grids) image_ids = image_ids.view(-1).to(input_ids.device) diff --git a/src/transformers/models/glm_ocr/modeling_glm_ocr.py b/src/transformers/models/glm_ocr/modeling_glm_ocr.py index 828a99a705b5..ca633364b07d 100644 --- a/src/transformers/models/glm_ocr/modeling_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modeling_glm_ocr.py @@ -19,13 +19,13 @@ # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional import torch import torch.nn as nn -import torch.nn.functional as F from torch.nn import LayerNorm from ... 
import initialization as init @@ -41,8 +41,14 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm_ocr import GlmOcrConfig, GlmOcrTextConfig, GlmOcrVisionConfig @@ -312,10 +318,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) @auto_docstring @@ -580,38 +584,24 @@ def __init__(self, config) -> None: self.post_init() def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb, pos_ids + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb, position_ids @merge_with_config_defaults @capture_outputs @auto_docstring - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + **kwargs, + ) -> torch.Tensor: r""" hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): The final hidden states of the model. @@ -621,21 +611,14 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) Returns: `torch.Tensor`: hidden_states. 
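# Migration suggested by the deprecation warning above, as a minimal sketch:
# callers that used `visual.rot_pos_emb(grid_thw)` are expected to build position
# ids with the shared helper and apply the module's rotary embedding themselves.
# `visual` here stands for any of the vision towers touched in this diff that
# expose `spatial_merge_size` and `rotary_pos_emb`.
from transformers.vision_utils import get_vision_position_ids

def rotary_from_grid(visual, grid_thw):
    position_ids = get_vision_position_ids(grid_thw, visual.spatial_merge_size)
    return visual.rotary_pos_emb(position_ids), position_ids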
""" + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + rotary_emb = self.rotary_pos_emb(position_ids) + emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for blk in self.blocks: hidden_states = blk( hidden_states, @@ -1013,6 +996,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1029,12 +1013,12 @@ def get_video_features( """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames - temp_frames_hw = [] - video_grid_thw_list = video_grid_thw.tolist() - for t, h, w in video_grid_thw_list: - repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1) - temp_frames_hw.append(repeated_row) - flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0) + t = video_grid_thw[:, 0] + hw = video_grid_thw[:, 1:] + # repeat each (h,w) row `t` times + flattened_hw = torch.repeat_interleave(hw, t, dim=0) + prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1) + flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1) vision_outputs = self.visual( pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs ) @@ -1044,6 +1028,7 @@ def get_video_features( return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1059,7 +1044,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1189,13 +1174,17 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) @@ -1273,6 +1262,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, @@ -1286,10 +1276,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, @@ -1303,7 +1292,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/glm_ocr/modular_glm_ocr.py b/src/transformers/models/glm_ocr/modular_glm_ocr.py index 2f71dded711d..782808bb69a9 100644 --- a/src/transformers/models/glm_ocr/modular_glm_ocr.py +++ b/src/transformers/models/glm_ocr/modular_glm_ocr.py @@ -16,12 +16,12 @@ import torch import torch.nn as nn -import torch.nn.functional as F from huggingface_hub.dataclasses import strict from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import auto_docstring +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..glm4v.configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig from ..glm4v.modeling_glm4v import ( Glm4vForConditionalGeneration, @@ -247,7 +247,12 @@ def __init__(self, config) -> None: hidden_act=config.hidden_act, ) - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + **kwargs, + ) -> torch.Tensor: r""" hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): The final hidden states of the model. @@ -257,21 +262,14 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) Returns: `torch.Tensor`: hidden_states. """ + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + rotary_emb = self.rotary_pos_emb(position_ids) + emb = torch.cat((rotary_emb, rotary_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for blk in self.blocks: hidden_states = blk( hidden_states, diff --git a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py index 0ae254feef39..d54259fe2902 100644 --- a/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py @@ -44,8 +44,14 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check, torch_int -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from 
.configuration_paddleocr_vl import PaddleOCRTextConfig, PaddleOCRVisionConfig, PaddleOCRVLConfig @@ -98,10 +104,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class PaddleOCRRotaryEmbedding(nn.Module): @@ -588,13 +592,13 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: def forward( self, pixel_values: torch.FloatTensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + grid_thw: torch.LongTensor | None = None, ) -> torch.Tensor: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ batch_size, squence_len, channel, height, width = pixel_values.shape @@ -607,8 +611,7 @@ def forward( start = 0 embeddings = embeddings.squeeze(0) tmp_embeddings = [] - for image_grid in image_grid_thw: - t, h, w = image_grid + for t, h, w in grid_thw: end = start + t * h * w image_embeddings = embeddings[start:end, :] position_embedding = self.interpolate_pos_encoding(image_embeddings, h, w).squeeze(0).repeat(t, 1) @@ -821,9 +824,8 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, inputs_embeds: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -831,35 +833,23 @@ def forward( Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. """ - device = inputs_embeds.device + # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), + # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). 
+ position_ids = get_vision_position_ids(grid_thw, 1, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, ) - split_hids = [] - split_wids = [] - for t, h, w in image_grid_thw: - image_pids = torch.arange(t * h * w, device=device) % (h * w) - sample_hids = image_pids // w - sample_wids = image_pids % w - split_hids.append(sample_hids) - split_wids.append(sample_wids) - width_position_ids = torch.concat(split_wids, dim=0) - height_position_ids = torch.concat(split_hids, dim=0) - - pids = torch.stack([height_position_ids, width_position_ids], dim=-1) - max_grid_size = pids.max() + 1 - rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size) - rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1) + rotary_embeddings = self.rotary_pos_emb(position_ids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) @@ -901,29 +891,24 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): The tensors corresponding to the input images. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw) - + hidden_states = self.embeddings(pixel_values, grid_thw=grid_thw) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, - cu_seqlens=cu_seqlens, + grid_thw=grid_thw, attention_mask=attention_mask, - image_grid_thw=image_grid_thw, **kwargs, ) @@ -952,25 +937,17 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + grid_thw: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
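The removed loop above built per-patch (row, column) ids with `// w` and `% w`; with `merge_size=1`, as the new comment notes, `get_vision_position_ids` is assumed to return exactly those pairs. A toy reconstruction of the removed logic for a single 2x3 image (illustrative, not the helper's actual source):

    import torch

    t, h, w = 1, 2, 3
    flat = torch.arange(t * h * w) % (h * w)
    position_ids = torch.stack([flat // w, flat % w], dim=-1)
    print(position_ids.tolist())
    # [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]] -> (row, col) per patch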
""" - return self.vision_model( - pixel_values=pixel_values, - cu_seqlens=cu_seqlens, - image_grid_thw=image_grid_thw, - **kwargs, - ) + return self.vision_model(pixel_values=pixel_values, grid_thw=grid_thw, **kwargs) @dataclass @@ -1211,6 +1188,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1226,22 +1204,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) - cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) - vision_outputs = self.visual( - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - cu_seqlens=cu_seqlens, - return_dict=True, - **kwargs, - ) + vision_outputs = self.visual(pixel_values=pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) image_embeds = vision_outputs.last_hidden_state image_embeds = self.projector(image_embeds, image_grid_thw) vision_outputs.pooler_output = image_embeds @@ -1346,7 +1309,9 @@ def forward( inputs_embeds = self.language_model.embed_tokens(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) @@ -1399,6 +1364,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, @@ -1412,7 +1378,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py index 02895d6e2576..16051aa7305a 100644 --- a/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py +++ b/src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py @@ -54,8 +54,9 @@ torch_compilable_check, torch_int, ) -from ...utils.generic import merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..ernie4_5.configuration_ernie4_5 import Ernie4_5Config from ..ernie4_5.modeling_ernie4_5 import ( Ernie4_5DecoderLayer, @@ -722,13 +723,13 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: def forward( self, pixel_values: torch.FloatTensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + grid_thw: torch.LongTensor | None = None, ) -> torch.Tensor: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ batch_size, squence_len, channel, height, width = pixel_values.shape @@ -741,8 +742,7 @@ def forward( start = 0 embeddings = embeddings.squeeze(0) tmp_embeddings = [] - for image_grid in image_grid_thw: - t, h, w = image_grid + for t, h, w in grid_thw: end = start + t * h * w image_embeddings = embeddings[start:end, :] position_embedding = self.interpolate_pos_encoding(image_embeddings, h, w).squeeze(0).repeat(t, 1) @@ -780,9 +780,8 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, inputs_embeds: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: r""" @@ -790,35 +789,23 @@ def forward( Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. 
""" - device = inputs_embeds.device + # Use merge_size=1: PaddleOCR merges patches in the projector (after the encoder), + # unlike Qwen which merges inside the encoder, so rotary positions here are simple (row, col). + position_ids = get_vision_position_ids(grid_thw, 1, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = inputs_embeds attention_mask = create_bidirectional_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, ) - split_hids = [] - split_wids = [] - for t, h, w in image_grid_thw: - image_pids = torch.arange(t * h * w, device=device) % (h * w) - sample_hids = image_pids // w - sample_wids = image_pids % w - split_hids.append(sample_hids) - split_wids.append(sample_wids) - width_position_ids = torch.concat(split_wids, dim=0) - height_position_ids = torch.concat(split_hids, dim=0) - - pids = torch.stack([height_position_ids, width_position_ids], dim=-1) - max_grid_size = pids.max() + 1 - rotary_embeddings_max_grid = self.rotary_pos_emb(max_grid_size) - rotary_embeddings = rotary_embeddings_max_grid[pids].flatten(1) + rotary_embeddings = self.rotary_pos_emb(position_ids) rotary_embeddings = rotary_embeddings.repeat(1, 2) position_embeddings = (rotary_embeddings.cos(), rotary_embeddings.sin()) @@ -860,29 +847,24 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, + grid_thw: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size * patch_size * image_channels)`): The tensors corresponding to the input images. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - hidden_states = self.embeddings(pixel_values, image_grid_thw=image_grid_thw) - + hidden_states = self.embeddings(pixel_values, grid_thw=grid_thw) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, - cu_seqlens=cu_seqlens, + grid_thw=grid_thw, attention_mask=attention_mask, - image_grid_thw=image_grid_thw, **kwargs, ) @@ -911,25 +893,17 @@ def __init__(self, config: PaddleOCRVisionConfig): def forward( self, pixel_values: torch.FloatTensor, - cu_seqlens: torch.Tensor, - image_grid_thw: list[tuple[int, int, int] | list[tuple[int, int, int]]] | None = None, + grid_thw: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, image_channels, patch_size, patch_size)`): The tensors corresponding to the input images. - cu_seqlens (`torch.Tensor` of shape `(num_images + 1,)`): - The cumulative sequence lengths of each image or video feature. 
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - return self.vision_model( - pixel_values=pixel_values, - cu_seqlens=cu_seqlens, - image_grid_thw=image_grid_thw, - **kwargs, - ) + return self.vision_model(pixel_values=pixel_values, grid_thw=grid_thw, **kwargs) class PaddleOCRVLModelOutputWithPast(Qwen2VLModelOutputWithPast): @@ -961,6 +935,7 @@ def set_input_embeddings(self, value): def get_video_features(self): raise AttributeError("PaddleOCRVLModel does not support video.") + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -976,22 +951,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype).unsqueeze(0) - cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) - vision_outputs = self.visual( - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - cu_seqlens=cu_seqlens, - return_dict=True, - **kwargs, - ) + vision_outputs = self.visual(pixel_values=pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) image_embeds = vision_outputs.last_hidden_state image_embeds = self.projector(image_embeds, image_grid_thw) vision_outputs.pooler_output = image_embeds @@ -1047,7 +1007,9 @@ def forward( inputs_embeds = self.language_model.embed_tokens(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index c8824b2f9730..a1fb96847d19 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -20,6 +20,7 @@ # limitations under the License. 
import math +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -51,9 +52,15 @@ torch_compilable_check, ) from ...utils.deprecation import deprecate_kwarg -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.hub import cached_file from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids, get_vision_window_index from .configuration_qwen2_5_omni import ( Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniBigVGANConfig, @@ -601,10 +608,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" seq_length, _ = hidden_states.size() @@ -616,27 +622,50 @@ def forward( query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + if is_flash_attention_requested(self.config): + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.out_proj(attn_output) @@ -661,7 +690,6 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - attention_mask: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: """ @@ -675,7 +703,6 @@ def forward( hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, - attention_mask=attention_mask, **kwargs, ) hidden_states = residual + hidden_states @@ -716,6 +743,97 @@ def forward(self, seqlen: int): return self.positional_embedding[:seqlen, :] +def 
chunk_and_pad_features( + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int, *, kwargs: dict | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + """Split audio features into fixed-size chunks and pad to uniform length, or pop precomputed pair from ``kwargs``. + + Each audio sample is split into chunks of ``n_window * 2`` frames (the last + chunk may be shorter), then all chunks are right-padded to the longest chunk. + + Args: + input_features: ``(feature_dim, total_frames)`` concatenated audio features. + feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window: half the target chunk size in frames. + kwargs: optional caller kwargs — if it contains both ``"padded_feature"`` and ``"chunk_lengths"`` they are popped and returned. + + Returns: + ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. + ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. + """ + if kwargs is not None: + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) + if padded_feature is not None and chunk_lengths is not None: + return padded_feature, chunk_lengths + chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() + chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, n_window * 2, chunk_lengths) + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + return padded_feature, chunk_lengths + + +def get_audio_cu_seqlens(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute cumulative sequence lengths for audio attention, or pop ``"cu_seqlens"`` from ``kwargs`` if precomputed. + + Applies one stride-2 convolution length reduction, then returns cumulative + boundaries for flash-attention-style sequence packing. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + kwargs: optional caller kwargs — if it contains ``"cu_seqlens"`` it is popped and returned. + + Returns: + ``(num_chunks + 1,)`` int32 cumulative sequence boundaries. + """ + if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: + return cu_seqlens + after_conv1 = (chunk_lengths - 1) // 2 + 1 + return F.pad(after_conv1.cumsum(0), (1, 0), value=0).to(torch.int32) + + +def get_valid_indices(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after one stride-2 conv, or pop ``"valid_indices"`` from ``kwargs`` if precomputed. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + kwargs: optional caller kwargs — if it contains ``"valid_indices"`` it is popped and returned. + + Returns: + ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_conv)`` grid. 
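All of these helpers share one pattern: accept the caller's `kwargs` dict and pop a precomputed tensor if the processor already supplied it, otherwise fall back to computing it on the fly. A minimal, self-contained sketch of that pattern (the helper name and dict key below are hypothetical, for illustration only):

    import torch

    def get_example_lengths(feature_lens: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor:
        # Reuse a precomputed value if the caller forwarded one, and pop it so it is
        # not passed further down the call stack.
        if kwargs is not None and (precomputed := kwargs.pop("example_lengths", None)) is not None:
            return precomputed
        return (feature_lens - 1) // 2 + 1           # fallback computation (illustrative)

    call_kwargs = {"example_lengths": torch.tensor([5, 9])}
    print(get_example_lengths(torch.tensor([10, 18]), kwargs=call_kwargs))  # tensor([5, 9])
    print(call_kwargs)                                                      # {} -- consumed, not forwarded twice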
+ """ + if kwargs is not None and (valid_indices := kwargs.pop("valid_indices", None)) is not None: + return valid_indices + after_conv1 = (chunk_lengths - 1) // 2 + 1 + max_len = after_conv1.max().item() + mask = torch.arange(max_len, device=chunk_lengths.device) < after_conv1.unsqueeze(1) + return mask.flatten().nonzero().squeeze(-1) + + +def get_pool_indices(feature_lens: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute indices for post-encoder stride-2 average pooling, or pop ``"pool_indices"`` from ``kwargs`` if precomputed. + + Args: + feature_lens: ``(batch_size,)`` mel spectrogram lengths. + kwargs: optional caller kwargs — if it contains ``"pool_indices"`` it is popped and returned. + + Returns: + ``(total_pooled,)`` flat index of first element of each stride-2 pair. + """ + if kwargs is not None and (pool_indices := kwargs.pop("pool_indices", None)) is not None: + return pool_indices + after_conv1 = (feature_lens - 1) // 2 + 1 + num_pooled = (after_conv1 - 2) // 2 + 1 + offsets = F.pad(after_conv1[:-1].cumsum(0), (1, 0), value=0) + pair_offsets = torch.repeat_interleave(offsets, num_pooled) + local_indices = torch.arange(num_pooled.sum(), device=feature_lens.device) + local_indices -= torch.repeat_interleave(F.pad(num_pooled[:-1].cumsum(0), (1, 0), value=0), num_pooled) + return pair_offsets + local_indices * 2 + + @auto_docstring( custom_intro=""" Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a @@ -765,85 +883,71 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value: nn.Module): self.conv1 = value - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. 
Though it will not be a 100% match for FA2's `varlen` path - if is_flash_attention_requested(self.config): - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - @merge_with_config_defaults @capture_outputs(tie_last_hidden_states=False) @auto_docstring - def forward(self, input_features, feature_lens=None, aftercnn_lens=None, **kwargs: Unpack[TransformersKwargs]): + def forward( + self, input_features=None, feature_lens=None, aftercnn_lens=None, **kwargs: Unpack[TransformersKwargs] + ): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length - aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): + aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`, *optional*): mel length after cnn """ - chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() - - chunk_lengths = torch.full((chunk_num.sum(),), self.n_window * 2, dtype=torch.long, device=feature_lens.device) - tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] - chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) - chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths) + padded_feature, chunk_lengths = chunk_and_pad_features( + input_features, feature_lens, self.n_window, kwargs=kwargs + ) + valid_indices = get_valid_indices(chunk_lengths, kwargs=kwargs) + pool_indices = get_pool_indices(feature_lens, kwargs=kwargs) + cu_seqlens = get_audio_cu_seqlens(chunk_lengths, kwargs=kwargs) - chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) - padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function( - chunk_list, chunk_lengths, padding_value=0, padding_side="right" + # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) + padded_feature = padded_feature.to(self.conv1.weight.dtype) + padded_mask = ( + (torch.arange(padded_feature.shape[2], device=padded_feature.device) < chunk_lengths.unsqueeze(1)) + .unsqueeze(1) + .long() ) padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2) - padded_embed = padded_embed + self.positional_embedding.positional_embedding[ : padded_embed.shape[1], : ].unsqueeze(0).to(padded_embed.dtype) - hidden_states = padded_embed[padded_mask_after_cnn] - cu_seqlens = torch.cat( - ( - torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32), - padded_mask_after_cnn.sum(1).cumsum(0), - ) - ).to(torch.int32) - attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) + hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, cu_seqlens=cu_seqlens, - attention_mask=attention_mask, **kwargs, ) hidden_states = layer_outputs[0] - hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) - token_audio_list = [] - for each_audio_states in hidden_states_list: - each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1) - each_audio_states = self.ln_post(each_audio_states) - each_audio_states = self.proj(each_audio_states) - 
token_audio_list.append(each_audio_states) - token_audio = torch.cat(token_audio_list, dim=0) - return BaseModelOutputWithPooling(last_hidden_state=token_audio) + # Post-process: stride-2 average pooling using precomputed indices, then project + hidden_states = (hidden_states[pool_indices] + hidden_states[pool_indices + 1]) / 2 + hidden_states = self.proj(self.ln_post(hidden_states)) + return BaseModelOutputWithPooling(last_hidden_state=hidden_states) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): """ Pads a sequence of tensors to their maximum length on indicated `padding_side`. Then prepares a mask so that pad tokens are not attended to. """ + warnings.warn( + f"`{self.__class__.__name__}.padded_and_mask_function` is deprecated and will be removed in a future version. Use `chunk_and_pad_features` and `get_audio_cu_seqlens` helpers instead.", + FutureWarning, + stacklevel=2, + ) max_len = tensor_len.max() dim = tensor_list[0].shape[0] padded_tensor = torch.full( @@ -877,15 +981,6 @@ def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, pad batch_mask_after_cnn.bool(), ) - # Ignore copy - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): - """ - Computes the output length of the convolutional layers and the output length of the audio encoder - """ - input_lengths = (input_lengths - 1) // 2 + 1 - output_lengths = (input_lengths - 2) // 2 + 1 - return input_lengths, output_lengths - def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -1067,10 +1162,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen2_5_VisionPatchEmbed(nn.Module): @@ -1153,75 +1246,29 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> self.post_init() def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw.tolist(): - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + 
warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size - grid_thw_list = grid_thw.tolist() - - for grid_t, grid_h, grid_w in grid_thw_list: - llm_grid_h, llm_grid_w = ( - grid_h // self.spatial_merge_size, - grid_w // self.spatial_merge_size, - ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) - index_padded = index_padded.reshape( - grid_t, - num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size, - ) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, - num_windows_h * num_windows_w, - vit_merger_window_size, - vit_merger_window_size, - ) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += grid_t * llm_grid_h * llm_grid_w - window_index = torch.cat(window_index, dim=0) - - return window_index, cu_window_seqlens + warnings.warn( + f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_vision_window_index` from `transformers.vision_utils` instead.", + FutureWarning, + stacklevel=2, + ) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + ) + return window_index, cu_window_seqlens.tolist() @merge_with_config_defaults @capture_outputs @@ -1238,35 +1285,30 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + kwargs=kwargs, ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + hidden_states = self.patch_embed(hidden_states) seq_len, _ = hidden_states.size() + reverse_indices = torch.argsort(window_index) hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = self.rotary_pos_emb(position_ids) rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - # Modification here for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: @@ -1282,7 +1324,6 @@ def forward( ) merged_hidden_states = self.merger(hidden_states) - reverse_indices = torch.argsort(window_index) merged_hidden_states = merged_hidden_states[reverse_indices, :] return BaseModelOutputWithPooling( @@ -1713,6 +1754,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1728,8 +1770,9 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1745,7 +1788,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) @can_return_tuple @auto_docstring @@ -1775,11 +1818,7 @@ def get_audio_features( ) feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) audio_outputs = self.audio_tower( - input_features, - feature_lens=feature_lens, - aftercnn_lens=audio_feat_lengths, - return_dict=True, - **kwargs, + input_features, feature_lens=feature_lens, aftercnn_lens=audio_feat_lengths, **kwargs ) if audio_outputs.last_hidden_state.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") @@ -1922,17 +1961,16 @@ def forward( # 2. Merge text , audios , image and video if input_features is not None: audio_features = self.get_audio_features( - input_features, - feature_attention_mask=feature_attention_mask, - audio_feature_lengths=audio_feature_lengths, - return_dict=True, + input_features, feature_attention_mask, audio_feature_lengths, return_dict=True, **kwargs ).last_hidden_state audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1940,7 +1978,9 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 4618b08cd574..6d885813acef 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -15,6 +15,7 @@ """PyTorch Qwen2.5Omni model (Audio, Image, Video).""" import math +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -43,9 +44,10 @@ torch_compilable_check, ) from ...utils.deprecation import deprecate_kwarg -from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, is_flash_attention_requested, merge_with_config_defaults from ...utils.hub import cached_file from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids, get_vision_window_index from ..llama.modeling_llama import LlamaRotaryEmbedding, rotate_half from 
..qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig from ..qwen2_5_vl.modeling_qwen2_5_vl import ( @@ -66,6 +68,97 @@ logger = logging.get_logger(__name__) +def chunk_and_pad_features( + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int, *, kwargs: dict | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + """Split audio features into fixed-size chunks and pad to uniform length, or pop precomputed pair from ``kwargs``. + + Each audio sample is split into chunks of ``n_window * 2`` frames (the last + chunk may be shorter), then all chunks are right-padded to the longest chunk. + + Args: + input_features: ``(feature_dim, total_frames)`` concatenated audio features. + feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window: half the target chunk size in frames. + kwargs: optional caller kwargs — if it contains both ``"padded_feature"`` and ``"chunk_lengths"`` they are popped and returned. + + Returns: + ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. + ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. + """ + if kwargs is not None: + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) + if padded_feature is not None and chunk_lengths is not None: + return padded_feature, chunk_lengths + chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() + chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, n_window * 2, chunk_lengths) + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + return padded_feature, chunk_lengths + + +def get_audio_cu_seqlens(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute cumulative sequence lengths for audio attention, or pop ``"cu_seqlens"`` from ``kwargs`` if precomputed. + + Applies one stride-2 convolution length reduction, then returns cumulative + boundaries for flash-attention-style sequence packing. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + kwargs: optional caller kwargs — if it contains ``"cu_seqlens"`` it is popped and returned. + + Returns: + ``(num_chunks + 1,)`` int32 cumulative sequence boundaries. + """ + if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: + return cu_seqlens + after_conv1 = (chunk_lengths - 1) // 2 + 1 + return F.pad(after_conv1.cumsum(0), (1, 0), value=0).to(torch.int32) + + +def get_valid_indices(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after one stride-2 conv, or pop ``"valid_indices"`` from ``kwargs`` if precomputed. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + kwargs: optional caller kwargs — if it contains ``"valid_indices"`` it is popped and returned. + + Returns: + ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_conv)`` grid. 
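A short worked example with toy numbers (editorial, not part of the diff) for the audio helpers defined here and in the function just below: a stride-2 convolution keeps `(len - 1) // 2 + 1` frames, which drives the cumulative attention boundaries, the indices of non-padding positions, and the number of stride-2 pooling pairs:

    import torch
    import torch.nn.functional as F

    # Two audio chunks of 4 and 2 mel frames.
    chunk_lengths = torch.tensor([4, 2])
    after_conv1 = (chunk_lengths - 1) // 2 + 1          # tensor([2, 1]): frames left after the stride-2 conv
    cu_seqlens = F.pad(after_conv1.cumsum(0), (1, 0), value=0).to(torch.int32)
    print(cu_seqlens)                                   # tensor([0, 2, 3], dtype=torch.int32)

    max_len = after_conv1.max().item()
    valid = (torch.arange(max_len) < after_conv1.unsqueeze(1)).flatten().nonzero().squeeze(-1)
    print(valid)                                        # tensor([0, 1, 2]): slot 3 is right-padding, dropped

    # One whole sample of 10 mel frames: 5 frames after the conv, so 2 stride-2 pooling pairs
    # starting at positions 0 and 2; the trailing frame 4 has no partner and is discarded.
    feature_lens = torch.tensor([10])
    num_pooled = ((feature_lens - 1) // 2 + 1 - 2) // 2 + 1
    print(num_pooled)                                   # tensor([2])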
+ """ + if kwargs is not None and (valid_indices := kwargs.pop("valid_indices", None)) is not None: + return valid_indices + after_conv1 = (chunk_lengths - 1) // 2 + 1 + max_len = after_conv1.max().item() + mask = torch.arange(max_len, device=chunk_lengths.device) < after_conv1.unsqueeze(1) + return mask.flatten().nonzero().squeeze(-1) + + +def get_pool_indices(feature_lens: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute indices for post-encoder stride-2 average pooling, or pop ``"pool_indices"`` from ``kwargs`` if precomputed. + + Args: + feature_lens: ``(batch_size,)`` mel spectrogram lengths. + kwargs: optional caller kwargs — if it contains ``"pool_indices"`` it is popped and returned. + + Returns: + ``(total_pooled,)`` flat index of first element of each stride-2 pair. + """ + if kwargs is not None and (pool_indices := kwargs.pop("pool_indices", None)) is not None: + return pool_indices + after_conv1 = (feature_lens - 1) // 2 + 1 + num_pooled = (after_conv1 - 2) // 2 + 1 + offsets = F.pad(after_conv1[:-1].cumsum(0), (1, 0), value=0) + pair_offsets = torch.repeat_interleave(offsets, num_pooled) + local_indices = torch.arange(num_pooled.sum(), device=feature_lens.device) + local_indices -= torch.repeat_interleave(F.pad(num_pooled[:-1].cumsum(0), (1, 0), value=0), num_pooled) + return pair_offsets + local_indices * 2 + + @auto_docstring(checkpoint="Qwen/Qwen2.5-Omni-7B") @strict class Qwen2_5OmniVisionEncoderConfig(Qwen2_5_VLVisionConfig): @@ -1097,10 +1190,9 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" seq_length, _ = hidden_states.size() @@ -1112,27 +1204,50 @@ def forward( query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + if is_flash_attention_requested(self.config): + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not 
self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.out_proj(attn_output) @@ -1149,7 +1264,6 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - attention_mask: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -1157,7 +1271,6 @@ def forward( hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, - attention_mask=attention_mask, **kwargs, ) hidden_states = residual + hidden_states @@ -1247,85 +1360,71 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value: nn.Module): self.conv1 = value - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. Though it will not be a 100% match for FA2's `varlen` path - if is_flash_attention_requested(self.config): - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - @merge_with_config_defaults @capture_outputs(tie_last_hidden_states=False) @auto_docstring - def forward(self, input_features, feature_lens=None, aftercnn_lens=None, **kwargs: Unpack[TransformersKwargs]): + def forward( + self, input_features=None, feature_lens=None, aftercnn_lens=None, **kwargs: Unpack[TransformersKwargs] + ): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length - aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): + aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`, *optional*): mel length after cnn """ - chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() - - chunk_lengths = torch.full((chunk_num.sum(),), self.n_window * 2, dtype=torch.long, device=feature_lens.device) - tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] - chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) - chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths) + padded_feature, chunk_lengths = chunk_and_pad_features( + input_features, feature_lens, self.n_window, kwargs=kwargs + ) + valid_indices = get_valid_indices(chunk_lengths, kwargs=kwargs) + pool_indices = get_pool_indices(feature_lens, kwargs=kwargs) + cu_seqlens = get_audio_cu_seqlens(chunk_lengths, kwargs=kwargs) - chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) - padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function( - chunk_list, chunk_lengths, padding_value=0, padding_side="right" + # Derive masks from chunk_lengths (traceable arithmetic + arange broadcasting) + padded_feature = padded_feature.to(self.conv1.weight.dtype) + padded_mask = ( + (torch.arange(padded_feature.shape[2], device=padded_feature.device) < chunk_lengths.unsqueeze(1)) + .unsqueeze(1) + .long() ) padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask 
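The non-flash fallback added in the attention module above runs attention chunk by chunk: it derives per-chunk lengths from `cu_seqlens`, splits the packed query/key/value along the sequence dimension, and concatenates the per-chunk outputs. A minimal sketch of that split (shapes follow the `(1, heads, total_len, head_dim)` layout used there; editorial only):

    import torch

    cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int32)
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]          # tensor([2, 3], dtype=torch.int32)
    packed = torch.randn(1, 4, 5, 8)                    # (batch=1, heads=4, total_len=5, head_dim=8)

    chunks = torch.split(packed, lengths.tolist(), dim=2)
    print([c.shape[2] for c in chunks])                 # [2, 3] -> each chunk attends only within itself
    assert torch.equal(torch.cat(chunks, dim=2), packed)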
padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2) - padded_embed = padded_embed + self.positional_embedding.positional_embedding[ : padded_embed.shape[1], : ].unsqueeze(0).to(padded_embed.dtype) - hidden_states = padded_embed[padded_mask_after_cnn] - cu_seqlens = torch.cat( - ( - torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32), - padded_mask_after_cnn.sum(1).cumsum(0), - ) - ).to(torch.int32) - attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) + hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, cu_seqlens=cu_seqlens, - attention_mask=attention_mask, **kwargs, ) hidden_states = layer_outputs[0] - hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) - token_audio_list = [] - for each_audio_states in hidden_states_list: - each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1) - each_audio_states = self.ln_post(each_audio_states) - each_audio_states = self.proj(each_audio_states) - token_audio_list.append(each_audio_states) - token_audio = torch.cat(token_audio_list, dim=0) - return BaseModelOutputWithPooling(last_hidden_state=token_audio) + # Post-process: stride-2 average pooling using precomputed indices, then project + hidden_states = (hidden_states[pool_indices] + hidden_states[pool_indices + 1]) / 2 + hidden_states = self.proj(self.ln_post(hidden_states)) + return BaseModelOutputWithPooling(last_hidden_state=hidden_states) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): """ Pads a sequence of tensors to their maximum length on indicated `padding_side`. Then prepares a mask so that pad tokens are not attended to. """ + warnings.warn( + f"`{self.__class__.__name__}.padded_and_mask_function` is deprecated and will be removed in a future version. Use `chunk_and_pad_features` and `get_audio_cu_seqlens` helpers instead.", + FutureWarning, + stacklevel=2, + ) max_len = tensor_len.max() dim = tensor_list[0].shape[0] padded_tensor = torch.full( @@ -1359,15 +1458,6 @@ def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, pad batch_mask_after_cnn.bool(), ) - # Ignore copy - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): - """ - Computes the output length of the convolutional layers and the output length of the audio encoder - """ - input_lengths = (input_lengths - 1) // 2 + 1 - output_lengths = (input_lengths - 2) // 2 + 1 - return input_lengths, output_lengths - def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: orig_dtype = tensor.dtype @@ -1520,35 +1610,30 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + kwargs=kwargs, ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + hidden_states = self.patch_embed(hidden_states) seq_len, _ = hidden_states.size() + reverse_indices = torch.argsort(window_index) hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = self.rotary_pos_emb(position_ids) rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - # Modification here for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: @@ -1564,7 +1649,6 @@ def forward( ) merged_hidden_states = self.merger(hidden_states) - reverse_indices = torch.argsort(window_index) merged_hidden_states = merged_hidden_states[reverse_indices, :] return BaseModelOutputWithPooling( @@ -1650,6 +1734,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1665,8 +1750,9 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1682,7 +1768,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) @can_return_tuple @auto_docstring @@ -1712,11 +1798,7 @@ def get_audio_features( ) feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) audio_outputs = self.audio_tower( - input_features, - feature_lens=feature_lens, - aftercnn_lens=audio_feat_lengths, - return_dict=True, - **kwargs, + input_features, feature_lens=feature_lens, aftercnn_lens=audio_feat_lengths, **kwargs ) if audio_outputs.last_hidden_state.shape[0] != sum(audio_output_lengths.tolist()): raise ValueError("length of audio_features should match audio_output_lengths") @@ -1859,17 +1941,16 @@ def forward( # 2. Merge text , audios , image and video if input_features is not None: audio_features = self.get_audio_features( - input_features, - feature_attention_mask=feature_attention_mask, - audio_feature_lengths=audio_feature_lengths, - return_dict=True, + input_features, feature_attention_mask, audio_feature_lengths, return_dict=True, **kwargs ).last_hidden_state audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output + image_embeds = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True, **kwargs + ).pooler_output image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1877,7 +1958,9 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output + video_embeds = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs + ).pooler_output video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 9e2812720d4c..ea8702b69e31 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -24,13 +24,13 @@ # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional import torch import torch.nn as nn -import torch.nn.functional as F from ... 
import initialization as init from ...activations import ACT2FN @@ -45,8 +45,14 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids, get_vision_window_index from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig @@ -124,10 +130,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen2_5_VLPatchMerger(nn.Module): @@ -380,75 +384,29 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw.tolist(): - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size - grid_thw_list = grid_thw.tolist() - - for grid_t, grid_h, grid_w in grid_thw_list: - llm_grid_h, llm_grid_w = ( - grid_h // self.spatial_merge_size, - grid_w // self.spatial_merge_size, - ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) - index_padded = index_padded.reshape( - grid_t, - num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size, - ) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, - num_windows_h * num_windows_w, - vit_merger_window_size, - vit_merger_window_size, - ) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += grid_t * llm_grid_h * llm_grid_w - window_index = torch.cat(window_index, dim=0) - - return window_index, cu_window_seqlens + warnings.warn( + f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_vision_window_index` from `transformers.vision_utils` instead.", + FutureWarning, + stacklevel=2, + ) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + ) + return window_index, cu_window_seqlens.tolist() @merge_with_config_defaults @capture_outputs @@ -465,36 +423,31 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + kwargs=kwargs, ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + hidden_states = self.patch_embed(hidden_states) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = self.rotary_pos_emb(position_ids) rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens @@ -508,8 +461,8 @@ def forward( **kwargs, ) - merged_hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) + merged_hidden_states = self.merger(hidden_states) merged_hidden_states = merged_hidden_states[reverse_indices, :] return BaseModelOutputWithPooling( @@ -1139,6 +1092,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1154,13 +1108,14 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1176,7 +1131,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1306,7 +1261,7 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw).pooler_output + image_embeds = self.get_image_features(pixel_values, image_grid_thw, **kwargs).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1314,7 +1269,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw).pooler_output + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, **kwargs).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds @@ -1399,6 +1354,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, @@ -1412,10 +1368,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, @@ -1429,7 +1384,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 038209892b6d..6a3dd95ae216 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -19,10 +19,10 @@ """PyTorch Qwen2.5-VL model.""" import itertools +import warnings import torch import torch.nn as nn -import torch.nn.functional as F from huggingface_hub.dataclasses import strict from ... 
import initialization as init @@ -40,6 +40,7 @@ from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids, get_vision_window_index from ..llama.modeling_llama import LlamaRMSNorm from ..qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig from ..qwen2_vl.modeling_qwen2_vl import ( @@ -219,75 +220,29 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw.tolist(): - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size - grid_thw_list = grid_thw.tolist() - - for grid_t, grid_h, grid_w in grid_thw_list: - llm_grid_h, llm_grid_w = ( - grid_h // self.spatial_merge_size, - grid_w // self.spatial_merge_size, - ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) - index_padded = index_padded.reshape( - grid_t, - num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size, - ) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, - num_windows_h * num_windows_w, - vit_merger_window_size, - vit_merger_window_size, - ) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += grid_t * llm_grid_h * llm_grid_w - window_index = torch.cat(window_index, dim=0) - - return window_index, cu_window_seqlens + warnings.warn( + 
f"`{self.__class__.__name__}.get_window_index` is deprecated and will be removed in a future version. Use `get_vision_window_index` from `transformers.vision_utils` instead.", + FutureWarning, + stacklevel=2, + ) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + ) + return window_index, cu_window_seqlens.tolist() @merge_with_config_defaults @capture_outputs @@ -304,36 +259,31 @@ def forward( Returns: `torch.Tensor`: hidden_states. """ - hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=hidden_states.device, - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + window_index, cu_window_seqlens = get_vision_window_index( + grid_thw, + self.spatial_merge_size, + self.window_size, + self.patch_size, + self.spatial_merge_unit, + kwargs=kwargs, ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + hidden_states = self.patch_embed(hidden_states) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) + + rotary_pos_emb = self.rotary_pos_emb(position_ids) rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for layer_num, blk in enumerate(self.blocks): if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens @@ -347,8 +297,8 @@ def forward( **kwargs, ) - merged_hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) + merged_hidden_states = self.merger(hidden_states) merged_hidden_states = merged_hidden_states[reverse_indices, :] return BaseModelOutputWithPooling( @@ -570,7 +520,7 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw).pooler_output + image_embeds = self.get_image_features(pixel_values, image_grid_thw, **kwargs).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -578,7 +528,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = 
self.get_video_features(pixel_values_videos, video_grid_thw).pooler_output + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, **kwargs).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 7ea940df2ae0..9c1c69d97248 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -19,13 +19,13 @@ """PyTorch Qwen2-VL model.""" import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional import torch import torch.nn as nn -import torch.nn.functional as F from torch.nn import LayerNorm from ... import initialization as init @@ -48,11 +48,13 @@ torch_compilable_check, ) from ...utils.generic import ( + handle_extra_kwargs, is_flash_attention_requested, maybe_autocast, merge_with_config_defaults, ) from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig @@ -278,10 +280,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class PatchEmbed(nn.Module): @@ -723,62 +723,33 @@ def get_device(self) -> torch.device: return self.blocks[0].mlp.fc2.weight.device def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) return rotary_pos_emb @merge_with_config_defaults @capture_outputs @auto_docstring def forward( - self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - **kwargs: Unpack[TransformersKwargs], + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs] ) -> torch.Tensor: r""" grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. """ + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) + hidden_states = self.patch_embed(hidden_states) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = self.rotary_pos_emb(position_ids) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for blk in self.blocks: hidden_states = blk( hidden_states, @@ -1093,6 +1064,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1108,13 +1080,14 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() video_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = video_embeds return vision_outputs + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1130,7 +1103,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
""" pixel_values = pixel_values.type(self.visual.dtype) - vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(vision_outputs.pooler_output, split_sizes) vision_outputs.pooler_output = image_embeds @@ -1258,7 +1231,7 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - image_embeds = self.get_image_features(pixel_values, image_grid_thw).pooler_output + image_embeds = self.get_image_features(pixel_values, image_grid_thw, **kwargs).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds @@ -1266,7 +1239,7 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: - video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw).pooler_output + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, **kwargs).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_placeholder_mask( input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds @@ -1318,6 +1291,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @auto_docstring def get_video_features( self, @@ -1331,10 +1305,9 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @auto_docstring def get_image_features( self, @@ -1348,7 +1321,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index bad700952673..194b49053baf 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -19,6 +19,7 @@ # limitations under the License. 
import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -45,9 +46,15 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig @@ -76,10 +83,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3_5TextRotaryEmbedding(nn.Module): @@ -1039,107 +1044,25 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) @merge_with_config_defaults @capture_outputs @@ -1154,12 +1077,19 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + hidden_states = self.patch_embed(hidden_states) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(position_ids) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -1167,16 +1097,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for blk in self.blocks: hidden_states = blk( hidden_states, @@ -1489,6 +1409,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1506,6 +1427,7 @@ def get_video_features( # Same implementation as for images return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1651,7 +1573,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithPooling = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) @@ -1662,7 +1584,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithPooling = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) @@ -1843,9 +1765,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
""" - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) @auto_docstring def get_image_features( @@ -1860,7 +1780,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple def forward( diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 710b63a28dba..6226893a1b2c 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -28,8 +28,9 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from ..qwen3.modeling_qwen3 import Qwen3ForCausalLM from ..qwen3_next.configuration_qwen3_next import Qwen3NextConfig from ..qwen3_next.modeling_qwen3_next import ( @@ -440,12 +441,19 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) Returns: `torch.Tensor`: hidden_states. """ - hidden_states = self.patch_embed(hidden_states) - - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + hidden_states = self.patch_embed(hidden_states) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(position_ids) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -453,16 +461,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for blk in self.blocks: hidden_states = blk( hidden_states, @@ -559,13 +557,13 @@ def forward( class Qwen3_5Model(Qwen3VLModel): _no_split_modules = ["Qwen3_5DecoderLayer", "Qwen3_5VisionBlock"] - def get_video_features( - self, 
- **super_kwargs, - ) -> tuple | BaseModelOutputWithPooling: + def get_video_features(self, **super_kwargs) -> tuple | BaseModelOutputWithPooling: # Same implementation as for images return super().get_video_features(**super_kwargs) + @handle_extra_kwargs(modality="image") + @can_return_tuple + @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, @@ -613,7 +611,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithPooling = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) @@ -624,7 +622,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithPooling = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) @@ -673,16 +671,10 @@ class Qwen3_5ForSequenceClassification(GenericForSequenceClassification, Qwen3_5 class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration): - def get_video_features( - self, - **super_kwargs, - ) -> tuple | BaseModelOutputWithPooling: + def get_video_features(self, **super_kwargs) -> tuple | BaseModelOutputWithPooling: return super().get_video_features(**super_kwargs) - def get_image_features( - self, - **super_kwargs, - ) -> tuple | BaseModelOutputWithPooling: + def get_image_features(self, **super_kwargs) -> tuple | BaseModelOutputWithPooling: return super().get_image_features(**super_kwargs) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index d7b45a276412..d634ffb8783b 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -19,6 +19,7 @@ # limitations under the License. 
import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -46,9 +47,15 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import OutputRecorder, capture_outputs +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig, Qwen3_5MoeVisionConfig @@ -77,10 +84,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3_5MoeTextRotaryEmbedding(nn.Module): @@ -1132,107 +1137,25 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) @merge_with_config_defaults @capture_outputs @@ -1247,12 +1170,19 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + hidden_states = self.patch_embed(hidden_states) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(position_ids) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -1260,16 +1190,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - for blk in self.blocks: hidden_states = blk( hidden_states, @@ -1614,6 +1534,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1631,6 +1552,7 @@ def get_video_features( # Same implementation as for images return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1776,7 +1698,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithPooling = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) @@ -1787,7 +1709,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithPooling = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) @@ -2042,9 +1964,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. 
""" - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) @auto_docstring def get_image_features( @@ -2059,7 +1979,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple def forward( diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 7b6c8b5b1bd4..f16aec3e3f21 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -20,6 +20,7 @@ # limitations under the License. import math +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Optional @@ -56,11 +57,13 @@ from ...utils import auto_docstring, can_return_tuple, torch_compilable_check from ...utils.generic import ( TransformersKwargs, + handle_extra_kwargs, is_flash_attention_requested, maybe_autocast, merge_with_config_defaults, ) from ...utils.output_capturing import OutputRecorder, capture_outputs +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_omni_moe import ( Qwen3OmniMoeAudioEncoderConfig, Qwen3OmniMoeCode2WavConfig, @@ -522,10 +525,9 @@ def __init__(self, config): def forward( self, hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - attention_mask: torch.Tensor | None = None, + cu_seqlens: torch.Tensor, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" seq_length, _ = hidden_states.size() @@ -537,27 +539,50 @@ def forward( query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask=attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seq_lens_k=cu_seqlens, - max_length_q=max_seqlen, - max_length_k=max_seqlen, - is_causal=False, - **kwargs, - ) + if is_flash_attention_requested(self.config): + # Flash Attention: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + 
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.out_proj(attn_output) @@ -582,7 +607,6 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - attention_mask: torch.Tensor | None = None, **kwargs, ) -> torch.Tensor: """ @@ -596,7 +620,6 @@ def forward( hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, - attention_mask=attention_mask, **kwargs, ) hidden_states = residual + hidden_states @@ -616,6 +639,96 @@ def forward( return outputs +def chunk_and_pad_features( + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int, *, kwargs: dict | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + """Split audio features into fixed-size chunks and pad to uniform length, or pop precomputed pair from ``kwargs``. + + Each audio sample is split into chunks of ``n_window * 2`` frames (the last + chunk may be shorter), then all chunks are right-padded to the longest chunk. + + Args: + input_features: ``(feature_dim, total_frames)`` concatenated audio features. + feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window: half the target chunk size in frames. + kwargs: optional caller kwargs — if it contains both ``"padded_feature"`` and ``"chunk_lengths"`` they are popped and returned. + + Returns: + ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. + ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. + """ + if kwargs is not None: + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) + if padded_feature is not None and chunk_lengths is not None: + return padded_feature, chunk_lengths + chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() + chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, n_window * 2, chunk_lengths) + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + return padded_feature, chunk_lengths + + +def get_valid_indices(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after CNN extraction, or pop ``"valid_indices"`` from ``kwargs`` if precomputed. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + kwargs: optional caller kwargs — if it contains ``"valid_indices"`` it is popped and returned. + + Returns: + ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_cnn)`` grid. 
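A small hand-worked example of the chunking path above, assuming `chunk_and_pad_features` and `get_valid_indices` behave exactly as written: one audio clip of 250 mel frames with `n_window=50` is split into chunks of `2 * n_window = 100` frames.

import torch

feature_lens = torch.tensor([250])
n_window = 50
chunk_num = torch.ceil(feature_lens / (n_window * 2)).long()  # tensor([3])
# chunk_lengths  -> tensor([100, 100, 50]); only the tail chunk keeps the 50-frame remainder
# padded_feature -> shape (3, feature_dim, 100); the tail chunk is right-padded to 100 frames
# get_valid_indices later drops the positions that correspond to that padding after the CNN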
+ """ + if kwargs is not None and (valid_indices := kwargs.pop("valid_indices", None)) is not None: + return valid_indices + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + max_len_after_cnn = feature_lens_after_cnn.max().item() + mask = torch.arange(max_len_after_cnn, device=chunk_lengths.device) < feature_lens_after_cnn.unsqueeze(1) + return mask.flatten().nonzero().squeeze(-1) + + +def get_audio_cu_seqlens( + chunk_lengths: torch.Tensor, + feature_lens: torch.Tensor, + n_window_infer: int, + n_window: int, + *, + kwargs: dict | None = None, +) -> torch.Tensor: + """Compute cumulative sequence lengths for audio attention windowing, or pop ``"cu_seqlens"`` from ``kwargs`` if precomputed. + + Splits each sample's post-CNN features into inference windows and returns + cumulative boundaries for flash-attention-style sequence packing. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window_infer: inference window size (in raw frames). + n_window: half the chunk size (in raw frames). + kwargs: optional caller kwargs — if it contains ``"cu_seqlens"`` it is popped and returned. + + Returns: + ``(num_windows + 1,)`` int32 cumulative sequence boundaries. + """ + if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: + return cu_seqlens + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + max_len_after_cnn = feature_lens_after_cnn.max().item() + cu_chunk_lens = [0] + n_window_ratio = n_window_infer // (n_window * 2) + window_aftercnn = max_len_after_cnn * n_window_ratio + for cnn_len in aftercnn_lens: + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + return torch.tensor(cu_chunk_lens, device=feature_lens.device).cumsum(-1, dtype=torch.int32) + + @auto_docstring( custom_intro=""" Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a @@ -673,56 +786,23 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value): self.conv2d1 = value - def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: - # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` - # NOTE: the created attention masl only approximates the ragged FA2 attention by - # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between - # blocks. 
Though it will not be a 100% match for FA2's `varlen` path - if is_flash_attention_requested(self.config): - return None - - seq_length = inputs_tensor.shape[0] - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(inputs_tensor.dtype).min, - device=inputs_tensor.device, - dtype=inputs_tensor.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - return attention_mask - @merge_with_config_defaults @capture_outputs(tie_last_hidden_states=False) @auto_docstring - def forward( - self, - input_features, - feature_lens=None, - aftercnn_lens=None, - **kwargs, - ): + def forward(self, input_features=None, feature_lens=None, **kwargs: Unpack[TransformersKwargs]): r""" feature_lens (`torch.LongTensor` of shape `(batch_size,)`): mel length - aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): - mel length after cnn """ - aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) - chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() - - chunk_lengths = torch.full((chunk_num.sum(),), self.n_window * 2, dtype=torch.long, device=feature_lens.device) - tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] - chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) - chunk_lengths[chunk_lengths == 0] = self.n_window * 2 - - chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) - padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) - feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) - padded_mask_after_cnn = nn.utils.rnn.pad_sequence( - [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], - batch_first=True, + padded_feature, chunk_lengths = chunk_and_pad_features( + input_features, feature_lens, self.n_window, kwargs=kwargs + ) + valid_indices = get_valid_indices(chunk_lengths, kwargs=kwargs) + cu_seqlens = get_audio_cu_seqlens( + chunk_lengths, feature_lens, self.n_window_infer, self.n_window, kwargs=kwargs ) + + # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) padded_feature = padded_feature.unsqueeze(1) # Split to chunk to avoid OOM during convolution padded_embeds = [] @@ -732,6 +812,7 @@ def forward( padded_embed = F.gelu(self.conv2d3(padded_embed)) padded_embeds.append(padded_embed) padded_embed = torch.cat(padded_embeds, dim=0) + b, c, f, t = padded_embed.size() padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) @@ -741,22 +822,13 @@ def forward( .to(padded_embed.dtype) ) padded_embed = padded_embed + positional_embedding - hidden_states = padded_embed[padded_mask_after_cnn] - cu_chunk_lens = [0] - window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2)) - for cnn_len in aftercnn_lens: - cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) - remainder = cnn_len % window_aftercnn - if remainder != 0: - cu_chunk_lens += [remainder] - cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) + hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, cu_seqlens, ) - hidden_states = layer_outputs[0] hidden_states = self.ln_post(hidden_states) @@ -765,11 +837,25 @@ def forward( hidden_states = 
self.proj2(hidden_states) return BaseModelOutputWithPooling(last_hidden_state=hidden_states) + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): """ Pads a sequence of tensors to their maximum length on indicated `padding_side`. Then prepares a mask so that pad tokens are not attended to. """ + warnings.warn( + f"`{self.__class__.__name__}.padded_and_mask_function` is deprecated and will be removed in a future version. Use `chunk_and_pad_features` and `get_audio_cu_seqlens` helpers instead.", + FutureWarning, + stacklevel=2, + ) max_len = tensor_len.max() dim = tensor_list[0].shape[0] padded_tensor = torch.full( @@ -803,15 +889,6 @@ def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, pad batch_mask_after_cnn.bool(), ) - # Ignore copy - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): - """ - Computes the output length of the convolutional layers and the output length of the audio encoder - """ - input_lengths = (input_lengths - 1) // 2 + 1 - output_lengths = (input_lengths - 2) // 2 + 1 - return input_lengths, output_lengths - def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -950,10 +1027,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3OmniMoeTextTopKRouter(nn.Module): @@ -1090,107 +1165,25 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - 
if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. 
Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) @merge_with_config_defaults @capture_outputs @@ -1207,12 +1200,19 @@ def forward( Returns: `torch.Tensor`: hidden_states. """ - hidden_states = self.patch_embed(hidden_states) - - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + hidden_states = self.patch_embed(hidden_states) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(position_ids) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -1220,16 +1220,6 @@ def forward( emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): hidden_states = blk( @@ -1923,6 +1913,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1938,8 +1929,9 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1955,7 +1947,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. 
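The audio path below also accepts precomputed chunking tensors through `**kwargs`; the key names come from the helper docstrings above, but whether the surrounding decorators forward them untouched is an assumption, so treat this as an illustrative sketch only.

# audio_outputs = self.audio_tower(
#     input_features,
#     feature_lens=feature_lens,
#     padded_feature=padded_feature,   # (num_chunks, feature_dim, max_chunk_len)
#     chunk_lengths=chunk_lengths,     # (num_chunks,)
#     valid_indices=valid_indices,     # (total_valid,)
#     cu_seqlens=cu_seqlens,           # (num_windows + 1,), int32
#     return_dict=True,
# )
# When a key is missing, the corresponding helper falls back to recomputing the value
# from `input_features` and `feature_lens`.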
""" pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) @can_return_tuple @auto_docstring @@ -1981,12 +1973,7 @@ def get_audio_features( audio_feature_lengths = None feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_outputs = self.audio_tower( - input_features, - feature_lens=feature_lens, - return_dict=True, - **kwargs, - ) + audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) return audio_outputs @@ -2133,10 +2120,7 @@ def forward( # 2. Merge text , audios , image and video if input_features is not None: audio_features = self.get_audio_features( - input_features, - feature_attention_mask=feature_attention_mask, - audio_feature_lengths=audio_feature_lengths, - return_dict=True, + input_features, feature_attention_mask, audio_feature_lengths, return_dict=True, **kwargs ).last_hidden_state audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) @@ -2144,7 +2128,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output image_embeds_multiscale = image_outputs.deepstack_features @@ -2156,7 +2140,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output video_embeds_multiscale = video_outputs.deepstack_features diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index 2c78ad930eba..646aef61f4b0 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -46,7 +46,7 @@ from ...processing_utils import ProcessorMixin, Unpack from ...tokenization_utils_base import TextInput from ...utils import auto_docstring, can_return_tuple, logging -from ...utils.generic import TransformersKwargs, merge_with_config_defaults +from ...utils.generic import TransformersKwargs, handle_extra_kwargs, merge_with_config_defaults from ...utils.output_capturing import OutputRecorder, capture_outputs from ...video_utils import VideoInput, make_batched_videos from ..mimi.modeling_mimi import MimiLayerScale @@ -103,6 +103,107 @@ logger = logging.get_logger(__name__) +def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor: + """Compute output lengths after the 3-layer CNN feature extractor with deepstack. + + Three stride-2 convolutions within each 100-frame block, plus 13 output frames + per full block from the deepstack path. 
+ """ + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + return ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + + +def chunk_and_pad_features( + input_features: torch.Tensor, feature_lens: torch.Tensor, n_window: int, *, kwargs: dict | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + """Split audio features into fixed-size chunks and pad to uniform length, or pop precomputed pair from ``kwargs``. + + Each audio sample is split into chunks of ``n_window * 2`` frames (the last + chunk may be shorter), then all chunks are right-padded to the longest chunk. + + Args: + input_features: ``(feature_dim, total_frames)`` concatenated audio features. + feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window: half the target chunk size in frames. + kwargs: optional caller kwargs — if it contains both ``"padded_feature"`` and ``"chunk_lengths"`` they are popped and returned. + + Returns: + ``padded_feature``: ``(num_chunks, feature_dim, max_chunk_len)`` padded chunks. + ``chunk_lengths``: ``(num_chunks,)`` actual length of each chunk before padding. + """ + if kwargs is not None: + padded_feature = kwargs.pop("padded_feature", None) + chunk_lengths = kwargs.pop("chunk_lengths", None) + if padded_feature is not None and chunk_lengths is not None: + return padded_feature, chunk_lengths + chunk_num = torch.ceil(feature_lens / (n_window * 2)).long() + chunk_lengths = torch.full((chunk_num.sum(),), n_window * 2, dtype=torch.long, device=feature_lens.device) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, n_window * 2, chunk_lengths) + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) + return padded_feature, chunk_lengths + + +def get_valid_indices(chunk_lengths: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Compute flat indices of valid (non-padding) positions after CNN extraction, or pop ``"valid_indices"`` from ``kwargs`` if precomputed. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + kwargs: optional caller kwargs — if it contains ``"valid_indices"`` it is popped and returned. + + Returns: + ``(total_valid,)`` flat indices into the ``(num_chunks * max_len_after_cnn)`` grid. + """ + if kwargs is not None and (valid_indices := kwargs.pop("valid_indices", None)) is not None: + return valid_indices + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + max_len_after_cnn = feature_lens_after_cnn.max().item() + mask = torch.arange(max_len_after_cnn, device=chunk_lengths.device) < feature_lens_after_cnn.unsqueeze(1) + return mask.flatten().nonzero().squeeze(-1) + + +def get_audio_cu_seqlens( + chunk_lengths: torch.Tensor, + feature_lens: torch.Tensor, + n_window_infer: int, + n_window: int, + *, + kwargs: dict | None = None, +) -> torch.Tensor: + """Compute cumulative sequence lengths for audio attention windowing, or pop ``"cu_seqlens"`` from ``kwargs`` if precomputed. + + Splits each sample's post-CNN features into inference windows and returns + cumulative boundaries for flash-attention-style sequence packing. + + Args: + chunk_lengths: ``(num_chunks,)`` pre-CNN chunk lengths. + feature_lens: ``(batch_size,)`` per-sample frame counts. + n_window_infer: inference window size (in raw frames). 
+ n_window: half the chunk size (in raw frames). + kwargs: optional caller kwargs — if it contains ``"cu_seqlens"`` it is popped and returned. + + Returns: + ``(num_windows + 1,)`` int32 cumulative sequence boundaries. + """ + if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: + return cu_seqlens + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + max_len_after_cnn = feature_lens_after_cnn.max().item() + cu_chunk_lens = [0] + n_window_ratio = n_window_infer // (n_window * 2) + window_aftercnn = max_len_after_cnn * n_window_ratio + for cnn_len in aftercnn_lens: + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + return torch.tensor(cu_chunk_lens, device=feature_lens.device).cumsum(-1, dtype=torch.int32) + + @dataclass @auto_docstring class BaseModelOutputWithDeepstackFeatures(BaseModelOutputWithPooling): @@ -900,28 +1001,23 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.conv2d1 = value - def forward( - self, - input_features, - feature_lens=None, - aftercnn_lens=None, - **kwargs, - ): - aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) - chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() - - chunk_lengths = torch.full((chunk_num.sum(),), self.n_window * 2, dtype=torch.long, device=feature_lens.device) - tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] - chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) - chunk_lengths[chunk_lengths == 0] = self.n_window * 2 - - chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) - padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) - feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) - padded_mask_after_cnn = nn.utils.rnn.pad_sequence( - [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], - batch_first=True, + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) + @auto_docstring + def forward(self, input_features=None, feature_lens=None, **kwargs: Unpack[TransformersKwargs]): + r""" + feature_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length + """ + padded_feature, chunk_lengths = chunk_and_pad_features( + input_features, feature_lens, self.n_window, kwargs=kwargs ) + valid_indices = get_valid_indices(chunk_lengths, kwargs=kwargs) + cu_seqlens = get_audio_cu_seqlens( + chunk_lengths, feature_lens, self.n_window_infer, self.n_window, kwargs=kwargs + ) + + # Add channel dim for Conv2d: (num_chunks, mel_bins, time) -> (num_chunks, 1, mel_bins, time) padded_feature = padded_feature.unsqueeze(1) # Split to chunk to avoid OOM during convolution padded_embeds = [] @@ -931,6 +1027,7 @@ def forward( padded_embed = F.gelu(self.conv2d3(padded_embed)) padded_embeds.append(padded_embed) padded_embed = torch.cat(padded_embeds, dim=0) + b, c, f, t = padded_embed.size() padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) @@ -940,22 +1037,13 @@ def forward( .to(padded_embed.dtype) ) padded_embed = padded_embed + positional_embedding - hidden_states = padded_embed[padded_mask_after_cnn] - cu_chunk_lens = [0] - window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2)) - for cnn_len in aftercnn_lens: - 
cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) - remainder = cnn_len % window_aftercnn - if remainder != 0: - cu_chunk_lens += [remainder] - cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) + hidden_states = torch.index_select(padded_embed.reshape(-1, padded_embed.shape[-1]), 0, valid_indices) for encoder_layer in self.layers: layer_outputs = encoder_layer( hidden_states, cu_seqlens, ) - hidden_states = layer_outputs[0] hidden_states = self.ln_post(hidden_states) @@ -1098,6 +1186,7 @@ def __init__(self, config): self.num_experts_per_tok = config.text_config.num_experts_per_tok self.router_aux_loss_coef = config.text_config.router_aux_loss_coef + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1113,8 +1202,9 @@ def get_video_features( The temporal, height and width of feature shape of each video in LLM. """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs) + return self.visual(pixel_values_videos, grid_thw=video_grid_thw, return_dict=True, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1130,7 +1220,7 @@ def get_image_features( The temporal, height and width of feature shape of each image in LLM. """ pixel_values = pixel_values.type(self.visual.dtype) - return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + return self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs) @can_return_tuple @auto_docstring @@ -1156,12 +1246,7 @@ def get_audio_features( audio_feature_lengths = None feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) - audio_outputs = self.audio_tower( - input_features, - feature_lens=feature_lens, - return_dict=True, - **kwargs, - ) + audio_outputs = self.audio_tower(input_features, feature_lens=feature_lens, return_dict=True, **kwargs) return audio_outputs @@ -1203,10 +1288,7 @@ def forward( # 2. 
Merge text , audios , image and video if input_features is not None: audio_features = self.get_audio_features( - input_features, - feature_attention_mask=feature_attention_mask, - audio_feature_lengths=audio_feature_lengths, - return_dict=True, + input_features, feature_attention_mask, audio_feature_lengths, return_dict=True, **kwargs ).last_hidden_state audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) @@ -1214,7 +1296,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output image_embeds_multiscale = image_outputs.deepstack_features @@ -1226,7 +1308,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output video_embeds_multiscale = video_outputs.deepstack_features diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 9522cb354789..5a6103724c0e 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -19,13 +19,13 @@ # limitations under the License. import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional import torch import torch.nn as nn -import torch.nn.functional as F from ... 
import initialization as init from ...activations import ACT2FN @@ -40,8 +40,14 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig @@ -99,10 +105,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class Qwen3VLVisionPatchMerger(nn.Module): @@ -660,107 +664,25 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) @merge_with_config_defaults @capture_outputs @@ -777,12 +699,19 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
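Shape sketch of the weighted four-corner lookup used in the rewritten body below; the random indices and weights stand in for what `get_vision_bilinear_indices_and_weights` is assumed to return (the same quantities the deleted `fast_pos_embed_interpolate` built inline).

import torch
import torch.nn as nn

num_tokens, hidden, grid_side = 6, 8, 4
pos_embed = nn.Embedding(grid_side * grid_side, hidden)
bilinear_indices = torch.randint(grid_side * grid_side, (4, num_tokens))          # four grid corners per patch
bilinear_weights = torch.rand(4, num_tokens)
bilinear_weights = bilinear_weights / bilinear_weights.sum(0, keepdim=True)       # corner weights sum to 1
pos_embeds = (pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0)  # (num_tokens, hidden)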
""" - hidden_states = self.patch_embed(hidden_states) - - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + hidden_states = self.patch_embed(hidden_states) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(position_ids) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -790,16 +719,6 @@ def forward( emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): hidden_states = blk( @@ -1123,6 +1042,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1140,6 +1060,7 @@ def get_video_features( # Same implementation as for images return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1288,7 +1209,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -1300,7 +1221,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features @@ -1423,9 +1344,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) @auto_docstring def get_image_features( @@ -1440,7 +1359,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple def forward( diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index e2b1dd42a68b..32b093caafb0 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -13,6 +13,7 @@ # limitations under the License. """PyTorch Qwen3-VL model.""" +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any @@ -20,7 +21,6 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from huggingface_hub.dataclasses import strict from ... import initialization as init @@ -37,9 +37,10 @@ from ...processing_utils import ProcessingKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, can_return_tuple, logging -from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...video_utils import VideoInput +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from ..llama.modeling_llama import LlamaRotaryEmbedding from ..qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLCausalLMOutputWithPast, @@ -446,107 +447,25 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) @merge_with_config_defaults @capture_outputs @@ -563,12 +482,19 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + hidden_states = self.patch_embed(hidden_states) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(position_ids) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -576,16 +502,6 @@ def forward( emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): hidden_states = blk( @@ -771,6 +687,7 @@ def get_rope_index( return super().get_rope_index(video_grid_thw=video_grid_thw, **super_kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -796,6 +713,7 @@ def get_image_features( return vision_output + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -846,7 +764,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -858,7 +776,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index be248a160e7d..c3fc822ff068 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -19,6 +19,7 @@ # limitations under the License. 
import itertools +import warnings from collections.abc import Callable from dataclasses import dataclass from typing import Any, Optional @@ -45,8 +46,14 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.generic import ( + handle_extra_kwargs, + is_flash_attention_requested, + maybe_autocast, + merge_with_config_defaults, +) from ...utils.output_capturing import OutputRecorder, capture_outputs +from ...vision_utils import get_vision_bilinear_indices_and_weights, get_vision_cu_seqlens, get_vision_position_ids from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig @@ -399,10 +406,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) def apply_rotary_pos_emb_vision( @@ -644,107 +649,25 @@ def __init__(self, config, *inputs, **kwargs) -> None: self.post_init() def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset : offset + num_tokens] = coords - offset += num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings + warnings.warn( + f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. 
Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.", + FutureWarning, + stacklevel=2, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + return rotary_pos_emb def fast_pos_embed_interpolate(self, grid_thw): - grid_thw_list = grid_thw.tolist() - grid_ts = [row[0] for row in grid_thw_list] - grid_hs = [row[1] for row in grid_thw_list] - grid_ws = [row[2] for row in grid_thw_list] - device = self.pos_embed.weight.device - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = ( - pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds + warnings.warn( + f"`{self.__class__.__name__}.fast_pos_embed_interpolate` is deprecated and will be removed in a future version. Use `get_vision_bilinear_indices_and_weights` from `transformers.vision_utils` and apply `self.pos_embed`.", + FutureWarning, + stacklevel=2, + ) + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, self.num_grid_per_side, self.config.spatial_merge_size + ) + return (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) @merge_with_config_defaults @capture_outputs @@ -761,12 +684,19 @@ def forward( Returns: `torch.Tensor`: hidden_states. 
""" - hidden_states = self.patch_embed(hidden_states) - - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds + bilinear_indices, bilinear_weights = get_vision_bilinear_indices_and_weights( + grid_thw, + self.num_grid_per_side, + self.config.spatial_merge_size, + kwargs=kwargs, + ) + position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) - rotary_pos_emb = self.rot_pos_emb(grid_thw) + hidden_states = self.patch_embed(hidden_states) + pos_embeds = (self.pos_embed(bilinear_indices) * bilinear_weights[:, :, None]).sum(0) + hidden_states = hidden_states + pos_embeds.to(hidden_states.dtype) + rotary_pos_emb = self.rotary_pos_emb(position_ids) seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape(seq_len, -1) @@ -774,16 +704,6 @@ def forward( emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): hidden_states = blk( @@ -1252,6 +1172,7 @@ def get_rope_index( mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return position_ids, mrope_position_deltas + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -1269,6 +1190,7 @@ def get_video_features( # Same implementation as for images return self.get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -1417,7 +1339,7 @@ def forward( if pixel_values is not None: image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features( - pixel_values, image_grid_thw, return_dict=True + pixel_values, image_grid_thw, return_dict=True, **kwargs ) image_embeds = image_outputs.pooler_output deepstack_image_embeds = image_outputs.deepstack_features @@ -1429,7 +1351,7 @@ def forward( if pixel_values_videos is not None: video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features( - pixel_values_videos, video_grid_thw, return_dict=True + pixel_values_videos, video_grid_thw, return_dict=True, **kwargs ) video_embeds = video_outputs.pooler_output deepstack_video_embeds = video_outputs.deepstack_features @@ -1605,9 +1527,7 @@ def get_video_features( video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. """ - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs) @auto_docstring def get_image_features( @@ -1622,7 +1542,7 @@ def get_image_features( image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. 
""" - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs) @can_return_tuple def forward( diff --git a/src/transformers/models/video_llama_3/configuration_video_llama_3.py b/src/transformers/models/video_llama_3/configuration_video_llama_3.py index 55230680d477..ce76ad960a28 100644 --- a/src/transformers/models/video_llama_3/configuration_video_llama_3.py +++ b/src/transformers/models/video_llama_3/configuration_video_llama_3.py @@ -69,7 +69,7 @@ class VideoLlama3Config(PreTrainedConfig): vision_config: dict | PreTrainedConfig | None = None image_token_id: int = 151655 video_token_id: int = 151656 - tie_word_embeddings: bool = False + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.vision_config, dict): diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 26d89b313167..18474c2739aa 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -34,8 +34,9 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check -from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, is_flash_attention_requested, merge_with_config_defaults from ...utils.output_capturing import capture_outputs +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..auto.modeling_auto import AutoModel from .configuration_video_llama_3 import VideoLlama3Config, VideoLlama3VisionConfig @@ -50,39 +51,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, grid_thw, merge_sizes) -> tuple[torch.Tensor, torch.Tensor]: - pos_ids = [] - for (t, h, w), merge_size in zip(grid_thw, merge_sizes): - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // merge_size, - merge_size, - w // merge_size, - merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // merge_size, - merge_size, - w // merge_size, - merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_thw = grid_thw[:, 1:].max() - - seq = torch.arange(max_grid_thw, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - rotary_pos_emb_full = torch.outer(seq, self.inv_freq) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - - return (emb.cos(), emb.sin()) + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1) class VideoLlama3VisionEmbeddings(nn.Module): @@ -452,19 +422,14 @@ def forward( merge_sizes (`torch.Tensor` of shape `(num_images_or_videos,)`): The spatial downsampling ratio of each image or video feature. 
""" - position_embeddings = self.rotary_pos_emb(grid_thw, merge_sizes) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + position_ids = get_vision_position_ids(grid_thw, merge_sizes, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.embeddings(pixel_values.type(self.dtype)) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + encoder_outputs: BaseModelOutput = self.encoder( hidden_states, cu_seqlens=cu_seqlens, @@ -546,6 +511,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -570,6 +536,7 @@ def get_video_features( **kwargs, ) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -681,7 +648,7 @@ def forward( image_embeds = None if pixel_values is not None: image_embeds = self.get_image_features( - pixel_values, image_grid_thw, image_merge_sizes, return_dict=True + pixel_values, image_grid_thw, image_merge_sizes, return_dict=True, **kwargs ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( @@ -692,7 +659,7 @@ def forward( video_embeds = None if pixel_values_videos is not None: video_embeds = self.get_video_features( - pixel_values_videos, video_grid_thw, video_merge_sizes, return_dict=True + pixel_values_videos, video_grid_thw, video_merge_sizes, return_dict=True, **kwargs ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) if video_compression_mask is not None: @@ -774,11 +741,14 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + @handle_extra_kwargs(modality="video") + @can_return_tuple @auto_docstring def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: torch.LongTensor | None = None, + video_merge_sizes: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -786,16 +756,19 @@ def get_video_features( The tensors corresponding to the input videos. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`, *optional*): + The spatial downsampling ratio of each video feature. 
""" - return self.model.get_video_features( - pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs - ) + return self.model.get_video_features(pixel_values_videos, video_grid_thw, video_merge_sizes, **kwargs) + @handle_extra_kwargs(modality="image") + @can_return_tuple @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None, + image_merge_sizes: torch.LongTensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -803,8 +776,10 @@ def get_image_features( The tensors corresponding to the input images. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): The temporal, height and width of feature shape of each image in LLM. + image_merge_sizes (`torch.Tensor` of shape `(num_images,)`, *optional*): + The spatial downsampling ratio of each image feature. """ - return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs) + return self.model.get_image_features(pixel_values, image_grid_thw, image_merge_sizes, **kwargs) @can_return_tuple @auto_docstring diff --git a/src/transformers/models/video_llama_3/modular_video_llama_3.py b/src/transformers/models/video_llama_3/modular_video_llama_3.py index 159ffdca6371..e06e36607efd 100644 --- a/src/transformers/models/video_llama_3/modular_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modular_video_llama_3.py @@ -45,13 +45,14 @@ can_return_tuple, logging, ) -from ...utils.generic import is_flash_attention_requested, merge_with_config_defaults +from ...utils.generic import handle_extra_kwargs, is_flash_attention_requested, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ...video_utils import ( VideoInput, group_videos_by_shape, reorder_videos, ) +from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from ..auto import CONFIG_MAPPING, AutoConfig from ..auto.modeling_auto import AutoModel from ..qwen2_vl.image_processing_pil_qwen2_vl import Qwen2VLImageProcessorPil @@ -105,7 +106,7 @@ class VideoLlama3Config(PreTrainedConfig): vision_config: dict | PreTrainedConfig | None = None image_token_id: int = 151655 video_token_id: int = 151656 - tie_word_embeddings: bool = False + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.vision_config, dict): @@ -128,39 +129,7 @@ def __post_init__(self, **kwargs): class VideoLlama3VisionRotaryEmbedding(VisionRotaryEmbedding): - def forward(self, grid_thw, merge_sizes) -> tuple[torch.Tensor, torch.Tensor]: - pos_ids = [] - for (t, h, w), merge_size in zip(grid_thw, merge_sizes): - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // merge_size, - merge_size, - w // merge_size, - merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // merge_size, - merge_size, - w // merge_size, - merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - - pos_ids = torch.cat(pos_ids, dim=0) - max_grid_thw = grid_thw[:, 1:].max() - - seq = torch.arange(max_grid_thw, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - rotary_pos_emb_full = torch.outer(seq, self.inv_freq) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - emb = 
torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - - return (emb.cos(), emb.sin()) + pass class VideoLlama3VisionEmbeddings(nn.Module): @@ -421,19 +390,14 @@ def forward( merge_sizes (`torch.Tensor` of shape `(num_images_or_videos,)`): The spatial downsampling ratio of each image or video feature. """ - position_embeddings = self.rotary_pos_emb(grid_thw, merge_sizes) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - # Select dtype based on the following factors: - # - FA2 requires that cu_seqlens_q must have dtype int32 - # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw - # See https://github.com/huggingface/transformers/pull/34852 for more information - dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, - ) - cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + position_ids = get_vision_position_ids(grid_thw, merge_sizes, kwargs=kwargs) + cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs) hidden_states = self.embeddings(pixel_values.type(self.dtype)) + rotary_pos_emb = self.rotary_pos_emb(position_ids) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + encoder_outputs: BaseModelOutput = self.encoder( hidden_states, cu_seqlens=cu_seqlens, @@ -514,6 +478,7 @@ def get_vision_position_ids(self): def compute_3d_position_ids(self): raise AttributeError("Not needed for VideoLLaMA3") + @handle_extra_kwargs(modality="video") @can_return_tuple @auto_docstring def get_video_features( @@ -538,6 +503,7 @@ def get_video_features( **kwargs, ) + @handle_extra_kwargs(modality="image") @can_return_tuple @auto_docstring def get_image_features( @@ -608,7 +574,7 @@ def forward( image_embeds = None if pixel_values is not None: image_embeds = self.get_image_features( - pixel_values, image_grid_thw, image_merge_sizes, return_dict=True + pixel_values, image_grid_thw, image_merge_sizes, return_dict=True, **kwargs ).pooler_output image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_placeholder_mask( @@ -619,7 +585,7 @@ def forward( video_embeds = None if pixel_values_videos is not None: video_embeds = self.get_video_features( - pixel_values_videos, video_grid_thw, video_merge_sizes, return_dict=True + pixel_values_videos, video_grid_thw, video_merge_sizes, return_dict=True, **kwargs ).pooler_output video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) if video_compression_mask is not None: @@ -690,6 +656,46 @@ class VideoLlama3ForConditionalGeneration(Qwen2VLForConditionalGeneration): def __init__(self, config: VideoLlama3Config): super().__init__(config) # just to add type hint on config + @handle_extra_kwargs(modality="image") + @can_return_tuple + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: torch.LongTensor | None = None, + image_merge_sizes: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
+ image_merge_sizes (`torch.Tensor` of shape `(num_images,)`, *optional*): + The spatial downsampling ratio of each image feature. + """ + return self.model.get_image_features(pixel_values, image_grid_thw, image_merge_sizes, **kwargs) + + @handle_extra_kwargs(modality="video") + @can_return_tuple + @auto_docstring + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: torch.LongTensor | None = None, + video_merge_sizes: torch.LongTensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + video_merge_sizes (`torch.Tensor` of shape `(num_videos,)`, *optional*): + The spatial downsampling ratio of each video feature. + """ + return self.model.get_video_features(pixel_values_videos, video_grid_thw, video_merge_sizes, **kwargs) + @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index f64cd5968552..6d754d1ecf7a 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -29,12 +29,12 @@ class Zamba2Config(PreTrainedConfig): Number of groups for the evolution matrices of mamba 2. n_mamba_heads (`int`, *optional*, defaults to 8): Number of heads for the evolution matrices of mamba 2. + use_mamba_kernels (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use the fast mamba kernels. use_conv_bias (`bool`, *optional*, defaults to `True`): Whether or not to use bias in the convolution layer of the mixer block. chunk_size (`int`, *optional*, defaults to 256): Size of the chunks that will comprise the sequence. - use_mamba_kernels (`bool`, *optional*, defaults to `True`): - Flag indicating whether or not to use the fast mamba kernels. use_mem_eff_path (`bool`, *optional*, defaults to `False`): Whether or not to use the fused conv1d and scan in mamba2 layers. add_bias_linear (`bool`, *optional*, defaults to `False`): diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index c6b5960f0849..d3b5c62b9373 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -905,6 +905,45 @@ def wrapper(self, *args, **kwargs): return wrapper +_KNOWN_MODALITIES = ("image", "video", "audio") + + +def handle_extra_kwargs(modality: str): + """ + Decorator for ``get_*_features`` methods that: + - strips the modality prefix from incoming kwargs whose full, prefixed name isn't already a named + parameter of the wrapped method (e.g. ``image_cu_seqlens`` → ``cu_seqlens``, forwarded via ``**kwargs``); + - drops kwargs prefixed with another known modality (e.g. ``video_*`` passed to an + image method), so an outer ``forward()`` can blindly forward ``**kwargs`` to each + modality method without leaking the wrong tensors into the wrong encoder; + - leaves everything else untouched (including kwargs that match a named parameter). + + Used so multimodal models can accept arbitrary precomputed tensors (``image_cu_seqlens``, + ``video_position_ids``, …) without enumerating each one in every signature.
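+
+    Illustrative sketch of the resulting kwarg handling (the method below mirrors the
+    ``get_image_features`` signatures in this PR; the kwarg values are placeholders)::
+
+        @handle_extra_kwargs(modality="image")
+        def get_image_features(self, pixel_values, image_grid_thw=None, **kwargs): ...
+
+        # image_grid_thw     -> kept as-is (matches a named parameter)
+        # image_cu_seqlens   -> forwarded as `cu_seqlens` via **kwargs
+        # video_position_ids -> dropped (belongs to another modality)
+        model.get_image_features(pixel_values, image_grid_thw=thw, image_cu_seqlens=cu, video_position_ids=pos)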
+ """ + prefix = f"{modality}_" + other_prefixes = tuple(f"{m}_" for m in _KNOWN_MODALITIES if m != modality) + + def decorator(func): + existing_params = set(inspect.signature(func).parameters) + + @wraps(func) + def wrapper(*args, **kwargs): + translated = {} + for k, v in kwargs.items(): + if k.startswith(other_prefixes): + continue + if k.startswith(prefix) and k not in existing_params: + translated[k.removeprefix(prefix)] = v + else: + translated[k] = v + return func(*args, **translated) + + return wrapper + + return decorator + + def merge_with_config_defaults(func): """ Decorator using config field (if they exist) as default value for some args and kwargs. Precedence is always diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index de11d23cbecf..146a458fd77d 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1522,6 +1522,16 @@ def torch_compilable_check(cond: Any, msg: str | Callable[[], str], error_type: import torch + # When tracing, msg may be an f-string with tensor values that dynamo can't trace + # (callable/isinstance on it breaks). Check compilation first and use torch._check + # without msg (it only serves as a compiler hint in that case). + if is_tracing(): + if isinstance(cond, torch.Tensor): + torch._check_tensor_all(cond) + else: + torch._check(cond) + return + if not callable(msg): # torch._check requires msg to be a callable but we want to keep the API simple for users def msg_callable(): diff --git a/src/transformers/vision_utils.py b/src/transformers/vision_utils.py new file mode 100644 index 000000000000..528dab1b81ef --- /dev/null +++ b/src/transformers/vision_utils.py @@ -0,0 +1,220 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Vision utility functions for pre-computing very dynamic and +data-dependent tensors that can break model graph capturing. + +All functions are standalone (no model weights) and compute tensors from +``grid_thw`` + config scalars. They are used by vision encoders and can be +precomputed before `torch.compile` / ``torch.export`` tracing since they +use untraceable ops (``repeat_interleave``, ``.tolist()``, ``nonzero()``, loops). + +Each ``get_*`` accepts an optional ``kwargs`` dict; if it contains the +precomputed tensor under the natural key (``"cu_seqlens"``, ``"position_ids"``, +…), the function pops and returns it instead of computing. Vision encoders +write ``x = get_vision_x(..., kwargs=kwargs)`` and the matching key is +removed from the caller's kwargs as a side-effect of the pop. +""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +def get_vision_cu_seqlens(grid_thw: torch.Tensor, *, kwargs: dict | None = None) -> torch.Tensor: + """Get cumulative sequence lengths from vision grid info, or pop from ``kwargs`` if precomputed. + + Args: + grid_thw: ``(num_images_or_videos, 3)`` — temporal, height, width per entry. 
+ kwargs: optional caller kwargs — if it contains ``"cu_seqlens"`` it is popped and returned. + + Returns: + ``cu_seqlens``: ``(total_frames + 1,)`` int32 cumulative sequence boundaries (one boundary per temporal frame of each image or video). + """ + if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None: + return cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + return F.pad(cu_seqlens, (1, 0), value=0) + + +def get_vision_position_ids( + grid_thw: torch.Tensor, spatial_merge_size: int | torch.Tensor, *, kwargs: dict | None = None +) -> torch.Tensor: + """Get (row, col) position IDs for vision rotary embeddings, or pop from ``kwargs`` if precomputed. + + Args: + grid_thw: ``(num_images_or_videos, 3)`` + spatial_merge_size: merge block size — either a single ``int`` (same for all images) + or a ``(num_images_or_videos,)`` tensor (per-image). + kwargs: optional caller kwargs — if it contains ``"position_ids"`` it is popped and returned. + + Returns: + ``position_ids``: ``(total_tokens, 2)`` long — (row, col) position per token. + """ + if kwargs is not None and (position_ids := kwargs.pop("position_ids", None)) is not None: + return position_ids + device = grid_thw.device + if isinstance(spatial_merge_size, int): + spatial_merge_size = torch.tensor([spatial_merge_size], device=device).expand(len(grid_thw)) + + position_ids = [] + for (t, h, w), m in zip(grid_thw.tolist(), spatial_merge_size.tolist()): + t, h, w, m = int(t), int(h), int(w), int(m) + hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten() + + wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten() + position_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + return torch.cat(position_ids, dim=0) + + +def get_vision_window_index( + grid_thw: torch.Tensor, + spatial_merge_size: int, + window_size: int, + patch_size: int, + spatial_merge_unit: int, + *, + kwargs: dict | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Get window attention indices, or pop ``"window_index"``/``"cu_window_seqlens"`` from ``kwargs`` if both precomputed. + + Args: + grid_thw: ``(num_images_or_videos, 3)`` + spatial_merge_size: merge block size from vision config. + window_size: window size from vision config. + patch_size: patch size from vision config. + spatial_merge_unit: ``spatial_merge_size ** 2``. + kwargs: optional caller kwargs — if it contains both ``"window_index"`` and ``"cu_window_seqlens"`` they are popped and returned. + + Returns: + ``window_index``: ``(total_tokens // spatial_merge_unit,)`` long — reorder indices for windowed attention (one index per merged token group). + ``cu_window_seqlens``: ``(num_windows + 1,)`` int32 — cumulative window boundaries.
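+
+    Illustrative sketch of how a windowed vision encoder (Qwen2.5-VL-style) typically consumes these
+    outputs; the reshape around ``spatial_merge_unit`` is shown for illustration only and is not code
+    from this PR::
+
+        window_index, cu_window_seqlens = get_vision_window_index(
+            grid_thw, spatial_merge_size, window_size, patch_size, spatial_merge_unit
+        )
+        seq_len = hidden_states.shape[0]
+        # reorder merge-unit groups of patch tokens into window order
+        hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
+        hidden_states = hidden_states[window_index, :, :].reshape(seq_len, -1)
+        # cu_window_seqlens then bounds each window for the attention kernel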
+ """ + if kwargs is not None: + window_index = kwargs.pop("window_index", None) + cu_window_seqlens = kwargs.pop("cu_window_seqlens", None) + if window_index is not None and cu_window_seqlens is not None: + return window_index, cu_window_seqlens + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = window_size // spatial_merge_size // patch_size + + for grid_t, grid_h, grid_w in grid_thw.tolist(): + grid_t, grid_h, grid_w = int(grid_t), int(grid_h), int(grid_w) + llm_grid_h = grid_h // spatial_merge_size + llm_grid_w = grid_w // spatial_merge_size + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, num_windows_h, vit_merger_window_size, num_windows_w, vit_merger_window_size + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, num_windows_h * num_windows_w, vit_merger_window_size, vit_merger_window_size + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += grid_t * llm_grid_h * llm_grid_w + + window_index = torch.cat(window_index, dim=0) + cu_window_seqlens = torch.tensor(cu_window_seqlens, device=grid_thw.device, dtype=torch.int32) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + return window_index, cu_window_seqlens + + +def get_vision_bilinear_indices_and_weights( + grid_thw: torch.Tensor, + num_grid_per_side: int, + spatial_merge_size: int, + *, + kwargs: dict | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Get bilinear interpolation indices/weights, or pop ``"bilinear_indices"``/``"bilinear_weights"`` from ``kwargs`` if both precomputed. + + Args: + grid_thw: ``(num_images_or_videos, 3)`` + num_grid_per_side: ``int(num_position_embeddings ** 0.5)`` from vision config. + spatial_merge_size: merge block size from vision config. + kwargs: optional caller kwargs — if it contains both ``"bilinear_indices"`` and ``"bilinear_weights"`` they are popped and returned. + + Returns: + ``bilinear_indices``: ``(4, total_thw)`` long — bilinear corner indices into pos_embed table. + ``bilinear_weights``: ``(4, total_thw)`` float — interpolation weights. 
+ """ + if kwargs is not None: + bilinear_indices = kwargs.pop("bilinear_indices", None) + bilinear_weights = kwargs.pop("bilinear_weights", None) + if bilinear_indices is not None and bilinear_weights is not None: + return bilinear_indices, bilinear_weights + N = num_grid_per_side + m = spatial_merge_size + device = grid_thw.device + + idx_parts: list[list[torch.Tensor]] = [[] for _ in range(4)] + weight_parts: list[list[torch.Tensor]] = [[] for _ in range(4)] + + for t, h, w in grid_thw.tolist(): + t, h, w = int(t), int(h), int(w) + + h_idxs = torch.linspace(0, N - 1, h, device=device) + w_idxs = torch.linspace(0, N - 1, w, device=device) + + h_floor = h_idxs.int() + w_floor = w_idxs.int() + h_ceil = (h_floor + 1).clamp(max=N - 1) + w_ceil = (w_floor + 1).clamp(max=N - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + bh_f = h_floor * N + bh_c = h_ceil * N + + raw_idx = [ + (bh_f[:, None] + w_floor[None, :]).flatten(), + (bh_f[:, None] + w_ceil[None, :]).flatten(), + (bh_c[:, None] + w_floor[None, :]).flatten(), + (bh_c[:, None] + w_ceil[None, :]).flatten(), + ] + raw_w = [ + ((1 - dh)[:, None] * (1 - dw)[None, :]).flatten(), + ((1 - dh)[:, None] * dw[None, :]).flatten(), + (dh[:, None] * (1 - dw)[None, :]).flatten(), + (dh[:, None] * dw[None, :]).flatten(), + ] + + h_idx = torch.arange(h, device=device).view(h // m, m) + w_idx = torch.arange(w, device=device).view(w // m, m) + reorder = (h_idx[:, :, None, None] * w + w_idx[None, None, :, :]).permute(0, 2, 1, 3).flatten().repeat(t) + + for i in range(4): + idx_parts[i].append(raw_idx[i][reorder]) + weight_parts[i].append(raw_w[i][reorder]) + + bilinear_indices = torch.stack([torch.cat(p) for p in idx_parts]) + bilinear_weights = torch.stack([torch.cat(p) for p in weight_parts]) + return bilinear_indices, bilinear_weights diff --git a/tests/models/glm4v_moe/test_modeling_glm4v_moe.py b/tests/models/glm4v_moe/test_modeling_glm4v_moe.py index 917a13bcefaf..dca1361cb597 100644 --- a/tests/models/glm4v_moe/test_modeling_glm4v_moe.py +++ b/tests/models/glm4v_moe/test_modeling_glm4v_moe.py @@ -51,7 +51,7 @@ def __init__( self, parent, batch_size=3, - seq_length=7, + seq_length=64, num_channels=3, ignore_index=-100, image_size=112, @@ -72,7 +72,7 @@ def __init__( "output_channels": 64, "hidden_act": "silu", "max_position_embeddings": 512, - "rope_parameters": {"type": "default", "mrope_section": [1, 1]}, + "rope_parameters": {"type": "default", "mrope_section": [2, 2], "partial_rotary_factor": 1.0}, "rope_theta": 10000, "tie_word_embeddings": True, "bos_token_id": 0, diff --git a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py index 7de1544d6893..55d3137a4ab7 100644 --- a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py +++ b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py @@ -460,7 +460,7 @@ def test_get_rope_index_video_with_audio(self): image_grid_thw = torch.empty((0, 3), dtype=torch.long) # 3 * 2 * 2 = 12 video tokens - video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long) + video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long, device=torch_device) # num_audio_tokens = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 # i.e.: 300 audio_seqlen -> 75 audio tokens diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index 5a425b434e7d..b956c624086e 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -328,7 +328,7 
@@ def test_video_forward(self): C * T * (P**2), ] ) - video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * B) + video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * B, device=torch_device) # sanity check assert pixel_values_videos.shape[0] == video_grid_thw.prod(dim=1).sum().item() diff --git a/tests/models/qwen3_5/test_modeling_qwen3_5.py b/tests/models/qwen3_5/test_modeling_qwen3_5.py index 668a4e513970..6b1f7a5a44ef 100644 --- a/tests/models/qwen3_5/test_modeling_qwen3_5.py +++ b/tests/models/qwen3_5/test_modeling_qwen3_5.py @@ -497,7 +497,7 @@ def test_image_forward(self): channels * temporal_patch * (patch_size**2), ] ) - image_grid_thw = torch.tensor([[1, 1, 1]] * (bsz * num_images)) + image_grid_thw = torch.tensor([[1, 1, 1]] * (bsz * num_images), device=torch_device) self.assertEqual(pixel_values.shape[0], image_grid_thw.prod(dim=1).sum().item()) insertion_point = 0 @@ -550,7 +550,7 @@ def test_video_forward(self): ] ) - video_grid_thw = torch.tensor([[patch_t, patch_h, patch_w]] * (bsz * num_video)) + video_grid_thw = torch.tensor([[patch_t, patch_h, patch_w]] * (bsz * num_video), device=torch_device) self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item()) input_ids[:, -1] = self.model_tester.pad_token_id diff --git a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py index e81e4d951917..c5935af0a3be 100644 --- a/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py +++ b/tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py @@ -454,7 +454,7 @@ def test_image_forward(self): channels * temporal_patch * (patch_size**2), ] ) - image_grid_thw = torch.tensor([[1, 1, 1]] * (bsz * num_images)) + image_grid_thw = torch.tensor([[1, 1, 1]] * (bsz * num_images), device=torch_device) self.assertEqual(pixel_values.shape[0], image_grid_thw.prod(dim=1).sum().item()) insertion_point = 0 @@ -507,7 +507,7 @@ def test_video_forward(self): ] ) - video_grid_thw = torch.tensor([[patch_t, patch_h, patch_w]] * (bsz * num_video)) + video_grid_thw = torch.tensor([[patch_t, patch_h, patch_w]] * (bsz * num_video), device=torch_device) self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item()) input_ids[:, -1] = self.model_tester.pad_token_id diff --git a/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py b/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py index 3d4113c922ae..fd9bbf0021a6 100644 --- a/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py +++ b/tests/models/qwen3_omni_moe/test_modeling_qwen3_omni_moe.py @@ -487,7 +487,7 @@ def test_get_rope_index_video_with_audio(self): image_grid_thw = torch.empty((0, 3), dtype=torch.long) # 3 * 2 * 2 = 12 video tokens - video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long) + video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long, device=torch_device) # num_audio_tokens = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 # i.e.: 300 audio_seqlen -> 75 audio tokens @@ -859,7 +859,7 @@ def test_small_model_integration_test_multiturn(self): **inputs, thinker_temperature=0, thinker_do_sample=False, return_audio=False, thinker_max_new_tokens=20 ) - EXPECTED_DECODED_TEXT = "user\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.\nuser\nHow about this one?\nassistant\nThe sound is a person coughing." 
+ EXPECTED_DECODED_TEXT = "user\nWhat's that sound and what kind of dog is this?\nassistant\nThe sound is glass shattering, and the dog appears to be a Labrador Retriever.\nuser\nHow about this one?\nassistant\nThis is the sound of a person coughing." self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), diff --git a/tests/models/qwen3_vl/test_modeling_qwen3_vl.py b/tests/models/qwen3_vl/test_modeling_qwen3_vl.py index 9874ce4a8203..0d3a0a1bec04 100644 --- a/tests/models/qwen3_vl/test_modeling_qwen3_vl.py +++ b/tests/models/qwen3_vl/test_modeling_qwen3_vl.py @@ -272,7 +272,7 @@ def test_image_forward(self): C * T * (P**2), ] ) - image_grid_thw = torch.tensor([[1, 1, 1]] * (B * num_images)) + image_grid_thw = torch.tensor([[1, 1, 1]] * (B * num_images), device=torch_device) self.assertEqual(pixel_values.shape[0], image_grid_thw.prod(dim=1).sum().item()) insertion_point = 0 @@ -327,7 +327,7 @@ def test_video_forward(self): ] ) - video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * (B * num_video)) + video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * (B * num_video), device=torch_device) # sanity check self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item()) diff --git a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py index 0b0523de3b71..e3dda9cd7b17 100644 --- a/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py +++ b/tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py @@ -217,7 +217,7 @@ def test_image_forward(self): C * T * (P**2), ] ) - image_grid_thw = torch.tensor([[1, 1, 1]] * (B * num_images)) + image_grid_thw = torch.tensor([[1, 1, 1]] * (B * num_images), device=torch_device) self.assertEqual(pixel_values.shape[0], image_grid_thw.prod(dim=1).sum().item()) insertion_point = 0 @@ -272,7 +272,7 @@ def test_video_forward(self): ] ) - video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * (B * num_video)) + video_grid_thw = torch.tensor([[patch_T, patch_H, patch_W]] * (B * num_video), device=torch_device) # sanity check self.assertEqual(pixel_values_videos.shape[0], video_grid_thw.prod(dim=1).sum().item()) @@ -420,7 +420,7 @@ def test_small_model_integration_test(self): inputs = inputs.to(torch_device) output = model.generate(**inputs, max_new_tokens=30, do_sample=False) - EXPECTED_DECODED_TEXT = "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a small wild cat native to the grasslands and steppes" + EXPECTED_DECODED_TEXT = "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a small wild cat native to the steppes and montane" self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT, @@ -498,7 +498,7 @@ def test_small_model_integration_test_expand(self): EXPECTED_DECODED_TEXT = [ "user\nWhat kind of dog is this?\nassistant\nThe animal in the image is not a dog. It is a **Pallas's cat** (*Otocolobus manul*), also known", - "user\nWhat kind of dog is this?\nassistant\nThe animal in the image is not a dog. It is a **Pallas's cat** (also known as the manul), a wild f" + "user\nWhat kind of dog is this?\nassistant\nThe animal in the image is not a dog. 
It is a **Pallas's cat** (also known as the **manul**), a" ] # fmt: skip self.assertEqual( self.processor.batch_decode(output, skip_special_tokens=True), @@ -548,7 +548,7 @@ def test_small_model_integration_test_batch_wo_image(self): output = model.generate(**inputs, max_new_tokens=30, do_sample=False) EXPECTED_DECODED_TEXT = [ - "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and steppes", + "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the steppes and montane", "user\nWho are you?\nassistant\nI am Qwen, a large-scale language model developed by Alibaba Cloud's Tongyi Lab. I can assist you with answering questions, creating text such" ] # fmt: skip self.assertEqual( @@ -575,8 +575,8 @@ def test_small_model_integration_test_batch_different_resolutions(self): output = model.generate(**inputs, max_new_tokens=30, do_sample=False) EXPECTED_DECODED_TEXT = [ - "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and steppes", - "user\nWhat kind of dog is this?\nassistant\nBased on the image provided, the animals are not dogs. They are two cats.\n\nHere is a description of the animals in the image:\n\n- " + "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and montane regions", + "user\nWhat kind of dog is this?\nassistant\nBased on the image provided, there is no dog present. The animals in the picture are two cats.\n\nHere is a description of the scene:\n-" ] # fmt: skip self.assertEqual( self.processor.batch_decode(output, skip_special_tokens=True), diff --git a/tests/models/video_llama_3/test_modeling_video_llama_3.py b/tests/models/video_llama_3/test_modeling_video_llama_3.py index 72f6ed865c48..e74bc4b51604 100644 --- a/tests/models/video_llama_3/test_modeling_video_llama_3.py +++ b/tests/models/video_llama_3/test_modeling_video_llama_3.py @@ -883,8 +883,8 @@ def test_small_model_integration_test_batch_wo_image(self): EXPECTED_DECODED_TEXT = Expectations( { ("cuda", None): [ - 'user\n\nDescribe the image.\nassistant\nThe image captures a vibrant night scene in a bustling Japanese city. A woman in a striking red dress', - 'user\nWhat is relativity?\nassistant\nRelativity is a scientific theory that describes the relationship between space and time. It was first proposed by' + "user\n\nDescribe the image.\nassistant\nThe image captures a vibrant night scene in a bustling Japanese city. A woman in a striking red dress", + "user\nWhat is relativity?\nassistant\nRelativity is a scientific theory that describes the relationship between space and time. It was first proposed by", ], ("xpu", None): [ "user\n\nDescribe the image.\nassistant\nThe image captures a vibrant nighttime scene on a bustling city street. A woman in a striking red dress", @@ -918,8 +918,8 @@ def test_small_model_integration_test_batch_different_resolutions(self): EXPECTED_DECODED_TEXT = Expectations( { ("cuda", None): [ - 'user\n\nDescribe the image.\nassistant\nThe image captures a vibrant night scene in a bustling Japanese city. A woman in a striking red dress', - 'user\n\nDescribe the image.\nassistant\nThe image depicts a striking urban scene at night. 
A person is standing in the center of a wet' + "user\n\nDescribe the image.\nassistant\nThe image captures a vibrant night scene in a bustling Japanese city. A woman in a striking red dress", + "user\n\nDescribe the image.\nassistant\nThe image depicts a striking urban scene at night. A person is standing in the center of a wet", ], ("xpu", None): [ "user\n\nDescribe the image.\nassistant\nThe image captures a vibrant night scene in a bustling Japanese city. A woman in a striking red dress",