63 commits
05d2d21
Extract pure vision/audio functions into standalone utilities
IlyasMoutawwakil Apr 13, 2026
fe46ba2
Fix stale compute_ docstring references to match actual function names
IlyasMoutawwakil Apr 13, 2026
84439a0
Revert mlcd changes — not part of this PR
IlyasMoutawwakil Apr 13, 2026
e62aa98
fix
IlyasMoutawwakil Apr 13, 2026
cbc1e22
Merge branch 'main' into hf-vision-audio-utils
IlyasMoutawwakil Apr 13, 2026
c1d7a8a
kwargs
IlyasMoutawwakil Apr 13, 2026
2771799
opt-in
IlyasMoutawwakil Apr 13, 2026
fa224e2
fix dtype
IlyasMoutawwakil Apr 13, 2026
ac2895d
style
IlyasMoutawwakil Apr 13, 2026
2f2787c
guard torch import
IlyasMoutawwakil Apr 13, 2026
d628d96
standardize
IlyasMoutawwakil Apr 13, 2026
2a014a4
propagate inputs
IlyasMoutawwakil Apr 13, 2026
957372a
fix docs
IlyasMoutawwakil Apr 13, 2026
4194ff1
fix docs
IlyasMoutawwakil Apr 13, 2026
836424b
auto docs
IlyasMoutawwakil Apr 13, 2026
11f73fd
more docs fixing
IlyasMoutawwakil Apr 13, 2026
71f90ec
fix omni
IlyasMoutawwakil Apr 13, 2026
a89d436
fix paddle
IlyasMoutawwakil Apr 13, 2026
c0fdc0d
revert paddle ocr until another time
IlyasMoutawwakil Apr 13, 2026
d1da022
finally fixed paddle ocr
IlyasMoutawwakil Apr 13, 2026
448ff2e
fix review
IlyasMoutawwakil Apr 13, 2026
6731028
revert chunking
IlyasMoutawwakil Apr 13, 2026
693ba9c
Potential fix for pull request finding
IlyasMoutawwakil Apr 13, 2026
d701016
Potential fix for pull request finding
IlyasMoutawwakil Apr 13, 2026
5472c4f
fix torch compilable check
IlyasMoutawwakil Apr 13, 2026
12a416c
Merge branch 'hf-vision-audio-utils' of https://github.com/huggingfac…
IlyasMoutawwakil Apr 13, 2026
4e7739b
fix docs
IlyasMoutawwakil Apr 13, 2026
47fed92
correct func name
IlyasMoutawwakil Apr 13, 2026
18a1788
fix omni
IlyasMoutawwakil Apr 13, 2026
4c6e1df
fix video llama 3
IlyasMoutawwakil Apr 13, 2026
247b445
fix video llama 3
IlyasMoutawwakil Apr 13, 2026
3c5e9a8
requires torch
IlyasMoutawwakil Apr 14, 2026
27677ed
add missing grid device
IlyasMoutawwakil Apr 14, 2026
45a03e4
keep rot emb in fp32
IlyasMoutawwakil Apr 14, 2026
5f3d2ae
fix test device
IlyasMoutawwakil Apr 14, 2026
1feb220
fix flm4v flex attention test
IlyasMoutawwakil Apr 15, 2026
e4c4138
rename to vision utils
IlyasMoutawwakil Apr 15, 2026
d401a33
only one get_rotary_pos_ids is needed
IlyasMoutawwakil Apr 15, 2026
fc49a3f
style
IlyasMoutawwakil Apr 15, 2026
fc4bf66
Merge branch 'main' into hf-vision-audio-utils
IlyasMoutawwakil Apr 15, 2026
4711af6
style
IlyasMoutawwakil Apr 15, 2026
e85551e
deprecate only
IlyasMoutawwakil Apr 20, 2026
531f13c
fix
IlyasMoutawwakil Apr 20, 2026
4c3d84d
simplify and revert processor changes
IlyasMoutawwakil Apr 20, 2026
9ea6203
renames
IlyasMoutawwakil Apr 21, 2026
67b0906
move some stuff to their original place
IlyasMoutawwakil Apr 21, 2026
b8323fb
style
IlyasMoutawwakil Apr 21, 2026
205e94d
Merge branch 'main' into hf-vision-audio-utils
IlyasMoutawwakil Apr 21, 2026
a6b071f
style
IlyasMoutawwakil Apr 21, 2026
5dcc3ea
Merge branch 'main' into hf-vision-audio-utils
IlyasMoutawwakil Apr 22, 2026
e9ac058
use chunked attention
IlyasMoutawwakil Apr 22, 2026
a7c2277
use decorator
IlyasMoutawwakil Apr 27, 2026
85556ab
Merge branch 'main' into hf-vision-audio-utils
IlyasMoutawwakil Apr 27, 2026
6d33d4a
pass kwargs and return_dict
IlyasMoutawwakil Apr 27, 2026
fe3bcc4
fix missing
IlyasMoutawwakil Apr 28, 2026
51f7e20
keep in and get from kwargs
IlyasMoutawwakil Apr 28, 2026
4838e17
revert some trailing commas
IlyasMoutawwakil Apr 28, 2026
4246a72
fix
IlyasMoutawwakil Apr 28, 2026
62d901c
fixes
IlyasMoutawwakil Apr 29, 2026
1481de6
video llama fixes
IlyasMoutawwakil Apr 29, 2026
3538f47
fix qwen3 vl
IlyasMoutawwakil Apr 29, 2026
2abfe3a
Merge branch 'main' into hf-vision-audio-utils
IlyasMoutawwakil Apr 29, 2026
e44ec7b
forgot glm ocr
IlyasMoutawwakil Apr 29, 2026
79 changes: 31 additions & 48 deletions src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py
@@ -19,6 +19,7 @@
# limitations under the License.

import itertools
import warnings
from collections.abc import Callable
from typing import Any, Optional

@@ -39,8 +40,14 @@
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
from ...utils.generic import (
handle_extra_kwargs,
is_flash_attention_requested,
maybe_autocast,
merge_with_config_defaults,
)
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids
from .configuration_ernie4_5_vl_moe import Ernie4_5_VLMoeConfig, Ernie4_5_VLMoeTextConfig, Ernie4_5_VLMoeVisionConfig


@@ -855,10 +862,8 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None:
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

def forward(self, seqlen: int) -> torch.Tensor:
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs
def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
return (position_ids.unsqueeze(-1) * self.inv_freq).flatten(1)
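# shape note (an assumption based on the helper's removed predecessor below):
# position_ids is (num_patches, 2), holding (h, w) indices; inv_freq is (dim // 2,),
# so the broadcast yields (num_patches, 2, dim // 2) and flatten(1) gives (num_patches, dim)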


@auto_docstring
@@ -894,32 +899,13 @@ def __init__(self, config) -> None:
self.post_init()

def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()

wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
warnings.warn(
f"`{self.__class__.__name__}.rot_pos_emb` is deprecated and will be removed in a future version. Use `get_vision_position_ids` from `transformers.vision_utils` and apply the rotary embedding module.",
FutureWarning,
stacklevel=2,
)
position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size)
rotary_pos_emb = self.rotary_pos_emb(position_ids)
return rotary_pos_emb
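
# For reference, a minimal standalone sketch of the position ids that
# `get_vision_position_ids` is expected to return, reconstructed from the loop
# deleted above; the real helper in `transformers.vision_utils` may differ in
# signature and kwargs handling, so treat the name and details as assumptions:
def vision_position_ids_sketch(grid_thw: torch.Tensor, spatial_merge_size: int) -> torch.Tensor:
    pos_ids = []
    for t, h, w in grid_thw:
        # height indices, regrouped into spatial-merge windows
        hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
        hpos_ids = hpos_ids.reshape(
            h // spatial_merge_size, spatial_merge_size,
            w // spatial_merge_size, spatial_merge_size,
        ).permute(0, 2, 1, 3).flatten()
        # width indices, same window regrouping
        wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
        wpos_ids = wpos_ids.reshape(
            h // spatial_merge_size, spatial_merge_size,
            w // spatial_merge_size, spatial_merge_size,
        ).permute(0, 2, 1, 3).flatten()
        # one (h, w) pair per patch, repeated across the t frames
        pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
    return torch.cat(pos_ids, dim=0)  # shape (total_patches, 2)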

@merge_with_config_defaults
@@ -931,21 +917,14 @@ def forward(
grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values.
"""
position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs)
cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs)

hidden_states = self.patch_embed(hidden_states)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
rotary_pos_emb = self.rotary_pos_emb(position_ids)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())

cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0,
# Select dtype based on the following factors:
# - FA2 requires that cu_seqlens_q must have dtype int32
# - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
# See https://github.com/huggingface/transformers/pull/34852 for more information
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
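
# For reference, a standalone sketch of what `get_vision_cu_seqlens` presumably
# computes, equivalent to the inline code removed above; the dtype choice carries
# over (FA2 requires int32, torch.onnx.export under tracing requires grid_thw's
# dtype). The name and kwargs handling are assumptions, not the actual helper:
def vision_cu_seqlens_sketch(grid_thw: torch.Tensor) -> torch.Tensor:
    # one sequence of h*w patches per frame, repeated t times per image/video
    seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
    cu_seqlens = seqlens.cumsum(
        dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32
    )
    return F.pad(cu_seqlens, (1, 0), value=0)  # prepend 0 so adjacent diffs give lengths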

for block in self.blocks:
hidden_states = block(
hidden_states,
@@ -1263,6 +1242,7 @@ def get_rope_index(
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas

@handle_extra_kwargs(modality="video")
@can_return_tuple
@auto_docstring
def get_video_features(
@@ -1288,6 +1268,7 @@ def get_video_features(
video_outputs.pooler_output = video_embeds
return video_outputs

@handle_extra_kwargs(modality="image")
@can_return_tuple
@auto_docstring
def get_image_features(
@@ -1434,15 +1415,19 @@ def forward(
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output
image_embeds = self.get_image_features(
pixel_values, image_grid_thw, return_dict=True, **kwargs
).pooler_output
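# passing **kwargs lets extra modality-specific inputs reach the vision tower;
# how they are filtered is presumably up to the `handle_extra_kwargs` decorator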
image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
image_mask, _ = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output
video_embeds = self.get_video_features(
pixel_values_videos, video_grid_thw, return_dict=True, **kwargs
).pooler_output
video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
_, video_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
@@ -1598,9 +1583,7 @@ def get_video_features(
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
"""
return self.model.get_video_features(
pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs
)
return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs)

@auto_docstring
def get_image_features(
@@ -1615,7 +1598,7 @@ def get_image_features(
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
"""
return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)
return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs)

@auto_docstring
@can_return_tuple
29 changes: 14 additions & 15 deletions src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py
@@ -20,7 +20,6 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub.dataclasses import strict
from torchvision.transforms.v2 import functional as tvF

@@ -51,8 +50,9 @@
can_return_tuple,
logging,
)
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.generic import handle_extra_kwargs, maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import OutputRecorder, capture_outputs
from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids
from ..ernie4_5_moe.configuration_ernie4_5_moe import Ernie4_5_MoeConfig
from ..ernie4_5_moe.modeling_ernie4_5_moe import (
Ernie4_5_MoeAttention,
@@ -701,21 +701,14 @@ def get_device(self):
def forward(
self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
) -> tuple | BaseModelOutputWithPooling:
position_ids = get_vision_position_ids(grid_thw, self.spatial_merge_size, kwargs=kwargs)
cu_seqlens = get_vision_cu_seqlens(grid_thw, kwargs=kwargs)

hidden_states = self.patch_embed(hidden_states)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
rotary_pos_emb = self.rotary_pos_emb(position_ids)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())

cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0,
# Select dtype based on the following factors:
# - FA2 requires that cu_seqlens_q must have dtype int32
# - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
# See https://github.com/huggingface/transformers/pull/34852 for more information
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

for block in self.blocks:
hidden_states = block(
hidden_states,
@@ -962,6 +955,7 @@ def get_rope_index(
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas

@handle_extra_kwargs(modality="video")
@can_return_tuple
@auto_docstring
def get_video_features(
@@ -981,6 +975,7 @@ def get_video_features(
video_outputs.pooler_output = video_embeds
return video_outputs

@handle_extra_kwargs(modality="image")
@can_return_tuple
@auto_docstring
def get_image_features(
@@ -1031,15 +1026,19 @@ def forward(
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output
image_embeds = self.get_image_features(
pixel_values, image_grid_thw, return_dict=True, **kwargs
).pooler_output
image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
image_mask, _ = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output
video_embeds = self.get_video_features(
pixel_values_videos, video_grid_thw, return_dict=True, **kwargs
).pooler_output
video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
_, video_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
33 changes: 20 additions & 13 deletions src/transformers/models/glm46v/modeling_glm46v.py
@@ -37,6 +37,7 @@
can_return_tuple,
torch_compilable_check,
)
from ...utils.generic import handle_extra_kwargs
from ..auto import AutoModel
from .configuration_glm46v import Glm46VConfig

@@ -254,6 +255,7 @@ def get_rope_index(
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas

@handle_extra_kwargs(modality="video")
@can_return_tuple
@auto_docstring
def get_video_features(
@@ -270,12 +272,12 @@ def get_video_features(
"""
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
# reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
temp_frames_hw = []
video_grid_thw_list = video_grid_thw.tolist()
for t, h, w in video_grid_thw_list:
repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1)
temp_frames_hw.append(repeated_row)
flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
t = video_grid_thw[:, 0]
hw = video_grid_thw[:, 1:]
# repeat each (h,w) row `t` times
flattened_hw = torch.repeat_interleave(hw, t, dim=0)
prefix_ones = video_grid_thw.new_ones(flattened_hw.shape[0], 1)
flattened_video_grid_thw = torch.cat([prefix_ones, flattened_hw], dim=1)
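# e.g. video_grid_thw = [[2, 3, 4]] (2 frames of a 3x4 patch grid)
# becomes flattened_video_grid_thw = [[1, 3, 4], [1, 3, 4]]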
vision_outputs = self.visual(
pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs
)
@@ -285,6 +287,7 @@

return vision_outputs

@handle_extra_kwargs(modality="image")
@can_return_tuple
@auto_docstring
def get_image_features(
@@ -300,7 +303,7 @@ def get_image_features(
The temporal, height and width of feature shape of each image in LLM.
"""
pixel_values = pixel_values.type(self.visual.dtype)
vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs)
vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs)
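# `return_dict=True` ensures a structured output here, so the `.pooler_output`
# access below is always valid, mirroring get_video_features above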
split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
image_embeds = torch.split(vision_outputs.pooler_output, split_sizes)
vision_outputs.pooler_output = image_embeds
@@ -430,13 +433,17 @@ def forward(
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output
image_embeds = self.get_image_features(
pixel_values, image_grid_thw, return_dict=True, **kwargs
).pooler_output
image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output
video_embeds = self.get_video_features(
pixel_values_videos, video_grid_thw, return_dict=True, **kwargs
).pooler_output
video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
_, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
@@ -514,6 +521,7 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

@handle_extra_kwargs(modality="video")
@auto_docstring
def get_video_features(
self,
@@ -527,10 +535,9 @@ def get_video_features(
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
"""
return self.model.get_video_features(
pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs
)
return self.model.get_video_features(pixel_values_videos, video_grid_thw, **kwargs)

@handle_extra_kwargs(modality="image")
@auto_docstring
def get_image_features(
self,
@@ -544,7 +551,7 @@ def get_image_features(
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
"""
return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)
return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs)

@can_return_tuple
@auto_docstring